diff --git a/README.md b/README.md index edfbcf59..f57a0e16 100644 --- a/README.md +++ b/README.md @@ -50,11 +50,17 @@ val d = c.relu() ### GGUF Model Loading ```kotlin -val source = SystemFileSystem.source(Path("model.gguf")).buffered() -val reader = GGUFReader(source) - -val tensor = reader.tensors.first { it.name == "token_embd.weight" } -val weights = reader.materialize(tensor) +// Recommended: streaming reader — memory-efficient, supports quantized types +val source = JvmRandomAccessSource.open("model.gguf") +StreamingGGUFReader.open(source).use { reader -> + println("Tensors: ${reader.tensorCount}") + + // Load specific tensor on demand (no whole-file loading) + val bytes = reader.loadTensor("token_embd.weight") + + // Or get a TensorStorage descriptor with encoding/placement metadata + val storage = reader.loadTensorStorage("token_embd.weight") +} ``` > **More examples:** [SKaiNET-examples](https://github.com/SKaiNET-developers/SKaiNET-examples) | [SKaiNET-notebook](https://github.com/SKaiNET-developers/SKaiNET-notebook) diff --git a/TURBOQUANT_ISSUES.md b/TURBOQUANT_ISSUES.md new file mode 100644 index 00000000..b12f37eb --- /dev/null +++ b/TURBOQUANT_ISSUES.md @@ -0,0 +1,582 @@ +# TurboQuant Implementation Tracker + +> Auto-generated from `prd.md` analysis on 2026-04-08. 
+> Branch: `feature/turboquant` + +## Legend + +| Symbol | Meaning | +|--------|---------| +| DONE | Implemented and tested | +| IN PROGRESS | Partially implemented | +| TODO | Not started | + +--- + +## Step 1: SKaiNET Core Preparation (PRD sections 1-6) + +### Completed + +- [x] **Storage & placement abstractions** — `TensorStorage`, `TensorEncoding`, `BufferHandle`, `Placement`, `LogicalDType` +- [x] **Zero-copy & ownership semantics** — Owned, Borrowed, Aliased, FileBacked, DeviceResident +- [x] **Packed quant unification** — `PackedBlockStorage` contract with Q4_K, Q8_0, Ternary +- [x] **Streaming GGUF loader** — `StreamingGGUFReader` + `StreamingGgufParametersLoader` +- [x] **Memory planning & tracking** — `MemoryPlanner`, `MemoryTracker`, `ActiveMemoryTracker` +- [x] **Transfer & materialization APIs** — `copyMaterialize()`, `copyToHost()`, `copyToDevice()` +- [x] **DSL annotations** — `@Place`, `@Weights` +- [x] **Benchmarks** — `StorageBenchmarks.kt` (Q4_K, Q8_0, Ternary dequant throughput) +- [x] **Acceptance criteria tests** — `AcceptanceCriteriaTest.kt` + +- [x] **KV-cache subsystem** — `KvCacheStore` interface, `DefaultKvCacheStore`, `KvCacheConfig`, `KvCacheMemoryReport` +- [x] **SDPA compressed K/V bridge** — `CompressedKvAttention` with dequant-on-read and raw storage paths +- [x] **Quants.kt port complete** — `byteShapeToQuantShape`, `quantByteSize`, `isBlockQuantized`, `validateQuantizedBytes` +- [x] **SafeTensors zero-copy loading** — `StorageAwareSafeTensorsLoader` with file-backed and borrowed modes + +### Remaining — None (Step 1 complete) + +--- + +### TQ-001: KV-Cache Subsystem + +| Field | Value | +|---|---| +| **Status** | DONE | +| **PRD section** | Step 1, Requirement 4 | +| **Priority** | High — blocks all Step 2 work | +| **Dependencies** | None (Step 1 foundations complete) | + +**Description:** +Create a `KvCacheStore` abstraction that supports append-by-token writes, layer/head addressing, compressed K/V block storage, 
backend-specific read/dequant flows, and asymmetric K/V policies. + +**Acceptance criteria:** +- [ ] `KvCacheStore` interface defined with append, read, and eviction APIs +- [ ] Layer and head indexing supported +- [ ] Storage accepts any `TensorEncoding` (including future TurboQuant) +- [ ] Backend-specific dequant dispatch is extensible +- [ ] Asymmetric K/V encoding policies configurable per layer +- [ ] Unit tests for append, read, eviction, and multi-head addressing + +**Key files to create/modify:** +- `skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/KvCacheStore.kt` (new) +- Tests in `skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/` + +--- + +### TQ-002: SDPA Integration for Compressed K/V + +| Field | Value | +|---|---| +| **Status** | DONE | +| **PRD section** | Step 1, Requirement 5 | +| **Priority** | High — blocks TurboQuant SDPA path | +| **Dependencies** | TQ-001 | + +**Description:** +Extend `scaledDotProductAttention()` in `TensorOps.kt` to detect compressed K/V from `KvCacheStore`, decompress only the needed tiles on read, and provide a seam for fused dequant+attention. + +**Acceptance criteria:** +- [ ] SDPA detects `TensorEncoding` of K/V inputs +- [ ] Compressed K/V triggers dequant-on-read path +- [ ] Only required tiles/blocks are decompressed (not full cache) +- [ ] Extension point exists for backend-fused kernels +- [ ] Tests with Q4_K and Q8_0 encoded K/V (as proxies before TurboQuant) + +**Key files to modify:** +- `skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOps.kt` + +--- + +### TQ-003: Complete Quants.kt Port + +| Field | Value | +|---|---| +| **Status** | DONE | +| **PRD section** | Step 1, Requirement 6 | +| **Priority** | Medium | +| **Dependencies** | None | + +**Description:** +Complete the Python-to-Kotlin port of `Quants.kt` and `Constants.kt`. 
Added `byteShapeToQuantShape`, `quantElementCount`, `quantByteSize`, `isBlockQuantized`, `quantBlockSize`, `quantTypeSize`, `validateQuantizedBytes`. Removed stale TODO from `Constants.kt`. + +**Acceptance criteria:** +- [ ] All quantization types from llama.cpp `quants.py` are ported +- [ ] Multi-dimension shape utilities work correctly +- [ ] `Constants.kt` port complete +- [ ] Unit tests for each ported quant type + +**Key files to modify:** +- `skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/Quants.kt` +- `skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/Constants.kt` + +--- + +### TQ-004: SafeTensors Zero-Copy / Mapped Loading + +| Field | Value | +|---|---| +| **Status** | DONE | +| **PRD section** | Step 1, Requirement 6 | +| **Priority** | Medium | +| **Dependencies** | None | + +**Description:** +Allow SafeTensors loaders to wrap or map buffers instead of always converting to dense arrays. Should produce `TensorStorage` with `FileBacked` or `Borrowed` buffer handles where possible. 
+ +**Acceptance criteria:** +- [ ] SafeTensors loader can produce `TensorStorage` with `FileBacked` handles +- [ ] No unnecessary heap copy for read-only weight access +- [ ] Falls back to `Owned` copy when mutation is required +- [ ] Integration test with real `.safetensors` file + +**Key files to modify:** +- `skainet-io/skainet-io-safetensors/` (loader implementation) + +--- + +## Step 2: TurboQuant Introduction (PRD sections 1-5) + +### Completed + +- [x] **TQ-010: TurboQuant encoding types** — `TurboQuantPolar`, `TurboQuantPolarQjl` in `TensorEncoding` +- [x] **TQ-011: Random rotation kernel** — `RandomRotation` with Walsh-Hadamard + sign flips +- [x] **TQ-012: Scalar quantizer** — `ScalarQuantizer` with per-group scales, 2/3/4/8-bit +- [x] **TQ-013: QJL residual** — `QjlResidual` with 1-4 bit residual encoding +- [x] **TQ-014: Bit-packing** — `BitPacker` for 2/3/4/8-bit codes +- [x] **TQ-015: KV block APIs** — `TurboQuantCodec` encode/decode + `TurboQuantKvCacheStore` +- [x] **TQ-016: PolarOnly e2e** — Full pipeline: rotation → quant → pack → unpack → dequant → inverse rotation +- [x] **TQ-017+018: SDPA write/read** — `CompressedKvAttention` + `TurboQuantKvCacheStore` integration +- [x] **TQ-019: Role-aware K/V policies** — Asymmetric key/value configs in `TurboQuantKvCacheStore` +- [x] **TQ-020: Presets** — `TurboQuantPresets` with safe-lowbit, balanced, experimental-max + +- [x] **TQ-021: DSL/annotation support** — `@KvCache`, `@KvCacheBypass` annotations +- [x] **TQ-022: CPU SIMD optimization** — `JvmTurboQuantKernels` with Java Vector API +- [x] **TQ-025: JMH benchmarks** — Encode/decode/pack/rotate/KV cache benchmarks + +### Remaining + +- [ ] **TQ-023: Metal/Apple Silicon backend** — Requires Metal shader development +- [ ] **TQ-024: Fused dequant+attention kernels** — Depends on TQ-023 + +--- + +### TQ-010: TurboQuant Encoding Types + +| Field | Value | +|---|---| +| **Status** | DONE | +| **PRD section** | Step 2, Product definition | +| **Priority** 
| High — blocks all TurboQuant kernels | +| **Dependencies** | None | + +**Description:** +Add TurboQuant variants to the sealed `TensorEncoding` hierarchy: `TurboQuantPolar` (PolarOnly) and `TurboQuantPolarQjl` (PolarPlusQjl), with configurable bit budgets and block sizes. + +**Acceptance criteria:** +- [ ] `TurboQuantPolar` encoding added to `TensorEncoding` +- [ ] `TurboQuantPolarQjl` encoding added to `TensorEncoding` +- [ ] Configurable: bits per element, block size, codebook variant +- [ ] `bytesPerBlock` / `elementsPerBlock` computed correctly +- [ ] Exhaustive `when` coverage in existing code updated + +**Key files to modify:** +- `skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/TensorEncoding.kt` + +--- + +### TQ-011: Random Rotation Kernel + +| Field | Value | +|---|---| +| **Status** | DONE | +| **PRD section** | Step 2, Functional requirement 1 | +| **Priority** | High | +| **Dependencies** | TQ-010 | + +**Description:** +Implement random rotation generation in common Kotlin. This is the first stage of the TurboQuant pipeline — rotating input vectors before quantization. + +**Acceptance criteria:** +- [ ] Deterministic random rotation matrix generation (seeded) +- [ ] Correct orthogonality properties verified +- [ ] Works for arbitrary head dimensions +- [ ] Common Kotlin (no platform-specific code) +- [ ] Unit tests verifying rotation properties + +**Key files to create:** +- `skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/turboquant/` (new package) + +--- + +### TQ-012: Scalar Quantization / Codebook Lookup + +| Field | Value | +|---|---| +| **Status** | DONE | +| **PRD section** | Step 2, Functional requirement 1 | +| **Priority** | High | +| **Dependencies** | TQ-011 | + +**Description:** +Implement scalar quantization with codebook lookup for the rotated vectors. Supports configurable bit widths (2, 3, 4, 8). 
+ +**Acceptance criteria:** +- [ ] Quantize rotated vector to N-bit codes +- [ ] Codebook lookup for dequantization +- [ ] Supports 2-bit, 3-bit, 4-bit, and 8-bit configurations +- [ ] Round-trip error within expected bounds per bit width +- [ ] Unit tests with known reference vectors + +**Key files to create:** +- `skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/turboquant/ScalarQuantizer.kt` (new) + +--- + +### TQ-013: QJL Residual Stage + +| Field | Value | +|---|---| +| **Status** | DONE | +| **PRD section** | Step 2, Functional requirement 1 | +| **Priority** | Medium — only needed for PolarPlusQjl variant | +| **Dependencies** | TQ-012 | + +**Description:** +Implement the QJL (Quantized Johnson-Lindenstrauss) residual stage for the PolarPlusQjl variant. This preserves inner-product accuracy by capturing quantization residuals. + +**Acceptance criteria:** +- [ ] QJL projection of quantization residual +- [ ] Inner-product error reduction verified vs PolarOnly +- [ ] Configurable residual bit budget +- [ ] Can be disabled (for PolarOnly path) +- [ ] Unit tests comparing IP accuracy with/without QJL + +**Key files to create:** +- `skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/turboquant/QjlResidual.kt` (new) + +--- + +### TQ-014: Bit-Packing Kernel + +| Field | Value | +|---|---| +| **Status** | DONE | +| **PRD section** | Step 2, Functional requirement 1 | +| **Priority** | High | +| **Dependencies** | TQ-012 | + +**Description:** +Implement bit-packing/unpacking for TurboQuant codes into compact byte arrays. Must support 2, 3, 4, and 8-bit packing. 
+ +**Acceptance criteria:** +- [ ] Pack N-bit codes into byte arrays +- [ ] Unpack byte arrays back to codes +- [ ] Round-trip correctness for all supported bit widths +- [ ] Append-friendly (can pack incrementally per token) +- [ ] Unit tests for boundary conditions and all bit widths + +**Key files to create:** +- `skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/turboquant/BitPacker.kt` (new) + +--- + +### TQ-015: KV Block Append/Read APIs + +| Field | Value | +|---|---| +| **Status** | DONE | +| **PRD section** | Step 2, Functional requirement 1 | +| **Priority** | High | +| **Dependencies** | TQ-001, TQ-014 | + +**Description:** +Implement append and read APIs that connect TurboQuant encoding/decoding to the `KvCacheStore`. New tokens are compressed on write; stored blocks are decompressed on read. + +**Acceptance criteria:** +- [ ] Append single token's K/V as TurboQuant-compressed block +- [ ] Read and decompress arbitrary range of cached tokens +- [ ] Supports both PolarOnly and PolarPlusQjl paths +- [ ] Memory-efficient (no full cache decompression) +- [ ] Integration test: append N tokens, read back, verify accuracy + +**Key files to create/modify:** +- `skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/turboquant/TurboQuantKvCodec.kt` (new) +- Integrates with `KvCacheStore` from TQ-001 + +--- + +### TQ-016: PolarOnly Variant Implementation + +| Field | Value | +|---|---| +| **Status** | DONE | +| **PRD section** | Step 2, Supported variants | +| **Priority** | High — primary production variant | +| **Dependencies** | TQ-011, TQ-012, TQ-014, TQ-015 | + +**Description:** +Wire together rotation + scalar quantization + bit-packing into the complete PolarOnly end-to-end path. This is the backend-friendly variant without QJL. 
+ +**Acceptance criteria:** +- [ ] End-to-end: float vector in -> compressed bytes -> float vector out +- [ ] Configurable bit budget (2, 3, 4 bits) +- [ ] Accuracy within expected bounds for each bit budget +- [ ] Works through KV append/read APIs +- [ ] Benchmark: compression ratio and throughput + +**Key files to modify:** +- Orchestration in `TurboQuantKvCodec.kt` + +--- + +### TQ-017: SDPA TurboQuant Write Path + +| Field | Value | +|---|---| +| **Status** | DONE | +| **PRD section** | Step 2, Functional requirement 2 | +| **Priority** | High | +| **Dependencies** | TQ-002, TQ-016 | + +**Description:** +Integrate TurboQuant compression into the SDPA write path so K/V are automatically compressed when stored to the KV cache. + +**Acceptance criteria:** +- [ ] SDPA stores K/V through TurboQuant compression when configured +- [ ] Compression is transparent to callers of `scaledDotProductAttention` +- [ ] Configurable per-layer (some layers can skip compression) +- [ ] No hidden densification + +**Key files to modify:** +- `skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOps.kt` + +--- + +### TQ-018: SDPA TurboQuant Read Path + +| Field | Value | +|---|---| +| **Status** | DONE | +| **PRD section** | Step 2, Functional requirement 2 | +| **Priority** | High | +| **Dependencies** | TQ-002, TQ-016 | + +**Description:** +Integrate TurboQuant decompression into the SDPA read path so attention is computed against decompressed K/V tiles. 
+ +**Acceptance criteria:** +- [ ] SDPA reads and decompresses only required K/V tiles +- [ ] Tile-level decompression (not full cache) +- [ ] Correct attention scores compared to uncompressed baseline +- [ ] Extension point for fused backend kernels + +**Key files to modify:** +- `skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOps.kt` + +--- + +### TQ-019: Role-Aware K/V Policies + +| Field | Value | +|---|---| +| **Status** | DONE | +| **PRD section** | Step 2, Functional requirement 3 | +| **Priority** | Medium | +| **Dependencies** | TQ-001, TQ-016 | + +**Description:** +Support independent compression policies for keys and values — different bit budgets, block sizes, and even different variants (e.g., Q8_0 for K + TurboQuant-4 for V). + +**Acceptance criteria:** +- [ ] K and V policies configurable independently +- [ ] Different bit budgets for K vs V +- [ ] Mixed encoding (e.g., Q8_0-K + TurboQuant-V) supported +- [ ] Per-layer policy override +- [ ] Configuration validated at init time + +**Key files to modify:** +- `KvCacheStore` from TQ-001 +- `TurboQuantKvCodec.kt` from TQ-015 + +--- + +### TQ-020: Presets + +| Field | Value | +|---|---| +| **Status** | DONE | +| **PRD section** | Step 2, Presets | +| **Priority** | Medium | +| **Dependencies** | TQ-019 | + +**Description:** +Implement named preset configurations: +- **safe-lowbit**: Q8_0-K + TurboQuant-4-V +- **balanced**: TurboQuant-4 / TurboQuant-4 +- **experimental-max**: TurboQuant-3 / TurboQuant-3 + +**Acceptance criteria:** +- [ ] Three named presets available +- [ ] Presets resolve to concrete K/V policy configurations +- [ ] Presets selectable via API and DSL +- [ ] Documentation of expected quality/compression trade-offs + +**Key files to create:** +- `skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/turboquant/TurboQuantPresets.kt` (new) + +--- + +### TQ-021: DSL / Annotation Support + +| Field | Value | +|---|---| +| **Status** 
| DONE | +| **PRD section** | Step 2, Recommended implementation order item 7 | +| **Priority** | Low | +| **Dependencies** | TQ-020 | + +**Description:** +Extend SKaiNET DSL/annotations (`@Place`, `@Weights`) to support TurboQuant KV cache configuration declaratively. + +**Acceptance criteria:** +- [ ] Annotation-based TurboQuant configuration for KV cache +- [ ] Preset selection via annotation +- [ ] Per-layer override via annotation +- [ ] Integrated with existing `PlacementAnnotations.kt` + +**Key files to modify:** +- `skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/PlacementAnnotations.kt` + +--- + +### TQ-022: CPU SIMD Optimization + +| Field | Value | +|---|---| +| **Status** | DONE | +| **PRD section** | Step 2, Functional requirement 5 | +| **Priority** | Medium | +| **Dependencies** | TQ-016 | + +**Description:** +Optimize TurboQuant kernels (rotation, quantization, bit-packing, dequant) with CPU SIMD using the same pattern as `JvmQuantizedVectorKernels.kt`. + +**Acceptance criteria:** +- [ ] SIMD-optimized rotation kernel +- [ ] SIMD-optimized quant/dequant kernels +- [ ] Benchmark showing speedup over common Kotlin reference +- [ ] Correctness matches reference implementation + +**Key files to create/modify:** +- `skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/` (new kernels) + +--- + +### TQ-023: Metal / Apple Silicon Backend + +| Field | Value | +|---|---| +| **Status** | TODO | +| **PRD section** | Step 2, Functional requirement 5 | +| **Priority** | Medium | +| **Dependencies** | TQ-016 | + +**Description:** +Implement Metal compute shaders for TurboQuant kernels targeting Apple Silicon unified memory. 
+ +**Acceptance criteria:** +- [ ] Metal shader for TurboQuant encode/decode +- [ ] Unified memory path (no CPU-GPU copy for KV cache) +- [ ] Correctness matches CPU reference +- [ ] Benchmark on Apple Silicon + +**Key files to create:** +- Metal backend (new shaders) + +--- + +### TQ-024: Fused Dequant + Attention Kernels + +| Field | Value | +|---|---| +| **Status** | TODO | +| **PRD section** | Step 2, Functional requirement 5 | +| **Priority** | Low — optimization after correctness | +| **Dependencies** | TQ-018, TQ-022 or TQ-023 | + +**Description:** +Fuse TurboQuant decompression with attention score computation to avoid materializing decompressed K/V. + +**Acceptance criteria:** +- [ ] Fused kernel avoids intermediate K/V buffer +- [ ] Correctness matches unfused path +- [ ] Benchmark showing memory and latency improvement +- [ ] At least one backend (CPU SIMD or Metal) + +**Key files to create:** +- Backend-specific fused kernel implementations + +--- + +### TQ-025: TurboQuant Benchmarks + +| Field | Value | +|---|---| +| **Status** | DONE | +| **PRD section** | Step 2, Acceptance criteria | +| **Priority** | High — validates the whole effort | +| **Dependencies** | TQ-016 | + +**Description:** +Add JMH benchmarks for TurboQuant KV compression: encode throughput, decode throughput, compression ratio, attention accuracy degradation. 
+
+**Acceptance criteria:**
+- [ ] Encode throughput benchmark (tokens/sec)
+- [ ] Decode throughput benchmark (tokens/sec)
+- [ ] Compression ratio measurement for each preset
+- [ ] Accuracy comparison vs uncompressed KV cache
+- [ ] Results documented
+
+**Key files to create:**
+- `skainet-lang/skainet-lang-core/src/jvmMain/kotlin/sk/ainet/lang/tensor/TurboQuantBenchmarks.kt` (new)
+
+---
+
+## Dependency Graph
+
+```
+Step 1 (all items complete — retained for dependency context):
+  TQ-003 (Quants.kt) — independent
+  TQ-004 (SafeTensors) — independent
+  TQ-001 (KV-cache) — independent
+  TQ-002 (SDPA compressed K/V) — depends on TQ-001
+
+Step 2:
+  TQ-010 (Encoding types) — independent
+  TQ-011 (Rotation) — depends on TQ-010
+  TQ-012 (Scalar quant) — depends on TQ-011
+  TQ-013 (QJL residual) — depends on TQ-012
+  TQ-014 (Bit-packing) — depends on TQ-012
+  TQ-015 (KV append/read) — depends on TQ-001, TQ-014
+  TQ-016 (PolarOnly e2e) — depends on TQ-011, TQ-012, TQ-014, TQ-015
+  TQ-017 (SDPA write) — depends on TQ-002, TQ-016
+  TQ-018 (SDPA read) — depends on TQ-002, TQ-016
+  TQ-019 (K/V policies) — depends on TQ-001, TQ-016
+  TQ-020 (Presets) — depends on TQ-019
+  TQ-021 (DSL) — depends on TQ-020
+  TQ-022 (CPU SIMD) — depends on TQ-016
+  TQ-023 (Metal) — depends on TQ-016
+  TQ-024 (Fused kernels) — depends on TQ-018, TQ-022 or TQ-023
+  TQ-025 (Benchmarks) — depends on TQ-016
+```
+
+## Recommended Implementation Order
+
+1. **TQ-001** + **TQ-003** + **TQ-004** + **TQ-010** (parallel — no dependencies between them)
+2. **TQ-002** + **TQ-011** (after TQ-001 and TQ-010)
+3. **TQ-012** (after TQ-011)
+4. **TQ-013** + **TQ-014** (parallel, after TQ-012)
+5. **TQ-015** (after TQ-001 + TQ-014)
+6. **TQ-016** + **TQ-025** (PolarOnly e2e + benchmarks)
+7. **TQ-017** + **TQ-018** + **TQ-019** (SDPA integration + policies)
+8. **TQ-020** + **TQ-022** + **TQ-023** (presets + backend optimization)
+9. 
**TQ-021** + **TQ-024** (DSL + fused kernels — last) diff --git a/TURBOQUANT_METAL.md b/TURBOQUANT_METAL.md new file mode 100644 index 00000000..ff0f5e70 --- /dev/null +++ b/TURBOQUANT_METAL.md @@ -0,0 +1,325 @@ +# TurboQuant Metal Backend — Implementation Task + +> Covers TQ-023 (Metal/Apple Silicon backend) and TQ-024 (Fused dequant+attention kernels) +> Status: TODO — requires Metal Shading Language + Kotlin/Native interop + +--- + +## Objective + +Implement TurboQuant KV-cache compression and decompression as Metal +compute shaders for Apple Silicon, enabling zero-copy unified-memory +KV cache and fused dequant+attention execution. + +## Why Metal + +- Apple Silicon unified memory eliminates CPU↔GPU copies for KV cache +- Metal Performance Shaders (MPS) provides optimized SDPA primitives +- Most on-device inference for SKaiNET targets macOS/iOS (Apple Silicon) +- TurboQuant decode is embarrassingly parallel — ideal for GPU compute + +## Prerequisites + +All prerequisites are complete: +- [x] TurboQuant encoding types (`TensorEncoding.TurboQuantPolar`, `TurboQuantPolarQjl`) +- [x] CPU reference kernels (rotation, quantize, bit-pack, QJL, codec) +- [x] `KvCacheStore` interface with `TurboQuantKvCacheStore` +- [x] `CompressedKvAttention` bridge with `RAW_STORAGE` extension point +- [x] `Placement` model with `DeviceKind.GPU`, `MemoryDomain.UNIFIED` +- [x] `BufferHandle.DeviceResident` for backend-managed buffers + +## Scope + +### In scope +- Metal compute shaders for TurboQuant encode/decode +- Fused dequant+SDPA Metal kernel +- Unified-memory KV cache (no CPU↔GPU copy) +- Kotlin/Native Metal interop for macOS/iOS targets +- Integration with existing `TensorOps.scaledDotProductAttention()` + +### Out of scope +- General-purpose Metal backend for all TensorOps (separate effort) +- CUDA/Vulkan backends +- Training support (inference only) + +--- + +## Architecture + +### Module structure + +``` +skainet-backends/ + skainet-backend-metal/ # New module + 
build.gradle.kts # KMP config: macosArm64, iosArm64 + src/ + commonMain/kotlin/sk/ainet/exec/metal/ + MetalTurboQuantOps.kt # Public API + MetalKvCacheStore.kt # Metal-backed KvCacheStore + MetalBufferPool.kt # MTLBuffer lifecycle management + nativeMain/kotlin/sk/ainet/exec/metal/ + MetalDevice.kt # MTLDevice + command queue wrapper + MetalShaderLibrary.kt # Compile & cache .metal shaders + MetalBufferHandle.kt # BufferHandle.DeviceResident for Metal + nativeMain/resources/ + turboquant.metal # Metal compute shaders + nativeTest/ + MetalTurboQuantOpsTest.kt # Correctness vs CPU reference +``` + +### Key interfaces to implement + +```kotlin +// MetalKvCacheStore: KvCacheStore backed by MTLBuffer in unified memory +class MetalKvCacheStore( + config: KvCacheConfig, + keyConfig: TurboQuantConfig, + valueConfig: TurboQuantConfig, + device: MetalDevice +) : KvCacheStore { + // KV data lives in MTLBuffer (unified memory) + // appendToken: GPU-side TurboQuant encode + // readKeys/readValues: GPU-side decode or zero-copy raw access +} + +// MetalTurboQuantOps: dispatch TurboQuant kernels to Metal GPU +class MetalTurboQuantOps(device: MetalDevice) { + fun encode(input: MTLBuffer, config: TurboQuantConfig): MTLBuffer + fun decode(encoded: MTLBuffer, config: TurboQuantConfig): MTLBuffer + fun fusedDequantAttention( + query: MTLBuffer, keyCache: MTLBuffer, valueCache: MTLBuffer, + config: TurboQuantConfig, scale: Float + ): MTLBuffer +} +``` + +### Integration with CompressedKvAttention + +The `RAW_STORAGE` dequant strategy in `CompressedKvAttention` is the +extension point. The Metal backend: +1. Returns raw `TensorStorage` with `BufferHandle.DeviceResident` pointing to MTLBuffer +2. 
The Metal SDPA kernel reads compressed K/V directly and fuses dequant + +```kotlin +// In MetalAttentionOps (extends or replaces scaledDotProductAttention) +override fun scaledDotProductAttention(query, key, value, mask, scale, causal): Tensor { + val keyStorage = compressedKv.loadKeyStorageRaw(layer) + if (keyStorage.buffer is BufferHandle.DeviceResident) { + // Dispatch fused Metal kernel + return metalOps.fusedDequantAttention(query, keyStorage, valueStorage, ...) + } + // Fallback to CPU + return super.scaledDotProductAttention(query, key, value, mask, scale, causal) +} +``` + +--- + +## Metal Shaders + +### File: `turboquant.metal` + +```metal +// Required compute kernels: + +// 1. turboquant_encode +// Per-thread: rotate → quantize → pack one head's vector +// Threadgroup: shared memory for Walsh-Hadamard butterfly +kernel void turboquant_encode( + device const float* input [[buffer(0)]], // [numHeads, headDim] + device uchar* packed_output [[buffer(1)]], // packed codes + device half* scales_output [[buffer(2)]], // per-group scales + constant TQParams& params [[buffer(3)]], // bits, headDim, seed + uint tid [[thread_position_in_grid]] +); + +// 2. turboquant_decode +// Per-thread: unpack → dequantize → inverse rotate one head's vector +kernel void turboquant_decode( + device const uchar* packed_input [[buffer(0)]], + device const half* scales_input [[buffer(1)]], + device float* output [[buffer(2)]], + constant TQParams& params [[buffer(3)]], + uint tid [[thread_position_in_grid]] +); + +// 3. 
turboquant_fused_sdpa (highest value kernel) +// Fuses: KV dequant + Q@K^T scaling + softmax + @V +// Avoids materializing decompressed K/V in global memory +kernel void turboquant_fused_sdpa( + device const float* query [[buffer(0)]], // [nHeads, seqLen, headDim] + device const uchar* key_packed [[buffer(1)]], // compressed keys + device const half* key_scales [[buffer(2)]], + device const uchar* value_packed [[buffer(3)]], // compressed values + device const half* value_scales [[buffer(4)]], + device float* output [[buffer(5)]], // [nHeads, seqLen, headDim] + constant SDPAParams& params [[buffer(6)]], + uint2 tid [[thread_position_in_grid]], + uint2 tgid [[threadgroup_position_in_grid]] +); + +// 4. walsh_hadamard_transform +// Threadgroup-cooperative WHT for rotation stage +// Uses threadgroup memory for butterfly communication +kernel void walsh_hadamard_transform( + device float* data [[buffer(0)]], + constant uint& log2_n [[buffer(1)]], + uint tid [[thread_position_in_threadgroup]], + uint tg_size [[threads_per_threadgroup]], + threadgroup float* shared [[threadgroup(0)]] +); +``` + +### Shader parameters + +```metal +struct TQParams { + uint bits; // 2, 3, 4, or 8 + uint headDim; // dimension per head + uint numHeads; // heads in this batch + uint seed; // rotation seed + uint groupSize; // quantization group size (32) + bool useQjl; // whether QJL residual is present + uint residualBits; // QJL residual bits (1-4) +}; + +struct SDPAParams { + uint nHeads; + uint nKVHeads; + uint seqLen; + uint kvLen; + uint headDim; + float scale; // 1/sqrt(headDim) + uint keyBits; + uint valueBits; + bool causal; +}; +``` + +--- + +## Implementation Plan + +### Phase 1: Metal infrastructure (no TurboQuant yet) + +| Task | Description | Files | +|---|---|---| +| M-001 | Create `skainet-backend-metal` module | `build.gradle.kts`, `settings.gradle.kts` | +| M-002 | `MetalDevice` wrapper (MTLDevice, command queue) | `MetalDevice.kt` | +| M-003 | `MetalShaderLibrary` (compile 
.metal, cache pipelines) | `MetalShaderLibrary.kt` | +| M-004 | `MetalBufferHandle` → `BufferHandle.DeviceResident` | `MetalBufferHandle.kt` | +| M-005 | `MetalBufferPool` (reusable MTLBuffer pool) | `MetalBufferPool.kt` | +| M-006 | Kotlin/Native cinterop for Metal.framework | `metal.def`, build config | + +### Phase 2: TurboQuant encode/decode shaders + +| Task | Description | Files | +|---|---|---| +| M-010 | `turboquant_encode` shader | `turboquant.metal` | +| M-011 | `turboquant_decode` shader | `turboquant.metal` | +| M-012 | `walsh_hadamard_transform` cooperative shader | `turboquant.metal` | +| M-013 | `MetalTurboQuantOps` Kotlin dispatch | `MetalTurboQuantOps.kt` | +| M-014 | Correctness tests vs CPU reference | `MetalTurboQuantOpsTest.kt` | + +### Phase 3: Metal KV cache store + +| Task | Description | Files | +|---|---|---| +| M-020 | `MetalKvCacheStore` with unified-memory buffers | `MetalKvCacheStore.kt` | +| M-021 | GPU-side append (encode on GPU, no CPU round-trip) | shader + Kotlin | +| M-022 | GPU-side read (decode on GPU for raw access) | shader + Kotlin | +| M-023 | Integration with `CompressedKvAttention.RAW_STORAGE` | bridge code | + +### Phase 4: Fused dequant+SDPA + +| Task | Description | Files | +|---|---|---| +| M-030 | `turboquant_fused_sdpa` shader | `turboquant.metal` | +| M-031 | Tiled attention with on-the-fly dequant | shader optimization | +| M-032 | Causal mask support in fused kernel | shader | +| M-033 | GQA (grouped-query attention) support | shader | +| M-034 | End-to-end benchmark vs CPU decode+SDPA | benchmark suite | + +### Phase 5: Integration & optimization + +| Task | Description | Files | +|---|---|---| +| M-040 | Wire Metal backend into `PlatformCpuOpsFactory` for macOS/iOS | factory impl | +| M-041 | Fallback to CPU when Metal unavailable | graceful degradation | +| M-042 | Unified-memory placement resolution in `MemoryPlanner` | planner update | +| M-043 | `@KvCache(device = GPU)` annotation handling | annotation 
processor | +| M-044 | Performance tuning: threadgroup sizes, occupancy | shader tuning | + +--- + +## Kotlin/Native Metal Interop + +### cinterop definition (`metal.def`) + +``` +language = Objective-C +headers = Metal/Metal.h MetalPerformanceShaders/MetalPerformanceShaders.h +compilerOpts = -framework Metal -framework MetalPerformanceShaders +linkerOpts = -framework Metal -framework MetalPerformanceShaders -framework Foundation +``` + +### Key ObjC types to bridge + +| Metal Type | Kotlin Usage | +|---|---| +| `MTLDevice` | GPU device handle | +| `MTLCommandQueue` | Serial command submission | +| `MTLCommandBuffer` | Batch of GPU commands | +| `MTLComputeCommandEncoder` | Dispatch compute kernels | +| `MTLBuffer` | GPU/unified memory buffer | +| `MTLComputePipelineState` | Compiled shader pipeline | +| `MTLLibrary` | Compiled shader library | + +### Unified memory pattern + +```kotlin +// Allocate in unified memory — accessible from both CPU and GPU +val buffer = device.newBuffer( + length = sizeInBytes, + options = MTLResourceStorageModeShared // unified memory +) + +// CPU can read/write directly (no copy needed) +val ptr = buffer.contents() + +// GPU kernel reads/writes same memory +encoder.setBuffer(buffer, offset = 0, index = 0) +encoder.dispatchThreads(...) 
+``` + +--- + +## Performance Targets + +| Metric | CPU Reference | Metal Target | +|---|---|---| +| TurboQuant encode (128d, 4-bit) | ~10 μs | < 2 μs | +| TurboQuant decode (128d, 4-bit) | ~8 μs | < 1 μs | +| Fused dequant+SDPA (8 heads, 128d, 1024 KV) | N/A (separate) | < 100 μs | +| KV cache memory (4-bit vs FP32) | 8x compression | 8x compression | +| CPU↔GPU copies for KV cache | N/A | 0 (unified memory) | + +## Acceptance Criteria + +- [ ] Metal shaders compile and run on Apple Silicon (M1+) +- [ ] Encode/decode correctness matches CPU reference within tolerance +- [ ] Fused dequant+SDPA produces correct attention output +- [ ] Zero CPU↔GPU copies for KV cache in unified memory mode +- [ ] Graceful fallback to CPU when Metal is unavailable +- [ ] Benchmark shows meaningful speedup over CPU reference path +- [ ] Works on both macOS (macosArm64) and iOS (iosArm64) + +## References + +- [Metal Shading Language Spec](https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf) +- [Metal Best Practices Guide](https://developer.apple.com/library/archive/documentation/3DDrawing/Conceptual/MTLBestPracticesGuide/) +- [MPSGraph Documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph) +- [TurboQuant paper (arXiv)](https://arxiv.org/html/2504.19874v1) +- SKaiNET existing backend: `skainet-backends/skainet-backend-cpu/` +- SKaiNET CPU SIMD kernels: `JvmQuantizedVectorKernels.kt`, `JvmTurboQuantKernels.kt` +- SKaiNET TurboQuant reference: `skainet-lang/.../ops/turboquant/` diff --git a/a.md b/a.md deleted file mode 100644 index ce091074..00000000 --- a/a.md +++ /dev/null @@ -1,254 +0,0 @@ - -❯ rebuild docker comspoe and start - -⏺ Bash(docker compose up --build -d 2>&1) - ⎿   Image postgres:16-alpine Pulling - Image postgres:16-alpine Pulled - Image lila-backend Building - … +199 lines (ctrl+o to expand) - ⎿  (timeout 5m) - -⏺ All containers are up and running: - - - Dashboard: http://localhost:3000 - - Backend API: 
http://localhost:8000 - - Docs: http://localhost:8081 - - Ollama: http://localhost:11434 - - ollama-init is pulling nomic-embed-text and phi3:mini in the background - -✻ Brewed for 33s - -❯ describe pipeline in as arcticle, use mermaid for data flow and data processing state flow - -⏺ LALI — Local AI Log Intelligence: Pipeline Architecture - - Overview - - LALI is a self-hosted log intelligence platform that transforms raw Kibana CSV exports into actionable insights - through a multi-stage pipeline combining deterministic analysis with LLM-powered reasoning. Every component runs - locally — no data leaves the infrastructure. - - --- - Data Flow - - flowchart LR - CSV["Kibana CSV Export"] - subgraph Ingestion - Parser["Parser"] - Masker["Sensitive Data Masker"] - end - subgraph Storage - LR_DB[("PostgreSQL")] - LogRecords["LogRecords"] - Sessions["SessionRecords"] - Embeddings["LogEmbeddings"] - end - subgraph Analysis Engine - Classification["Classification"] - Clustering["Error Clustering"] - Frequency["Frequency Analysis"] - Anomaly["Anomaly Detection\n(Z-Score)"] - end - subgraph LLM Layer - Ollama["Ollama"] - Summary["Summary Generation"] - Explain["Anomaly Explanation"] - RootCause["Root Cause Analysis"] - Embed["Semantic Embedding"] - SemanticAnom["Semantic Anomaly\nDetection"] - end - subgraph Presentation - API["FastAPI REST API"] - Dashboard["React Dashboard"] - end - - CSV --> Parser --> Masker --> LogRecords --> LR_DB - LogRecords --> Sessions --> LR_DB - - LR_DB --> Classification - LR_DB --> Clustering - LR_DB --> Frequency - LR_DB --> Anomaly - - LR_DB --> Embed --> Embeddings --> LR_DB - Embeddings --> SemanticAnom - - Classification --> Summary - Clustering --> Summary - Clustering --> RootCause - Frequency --> Summary - Anomaly --> Explain - - Summary --> Ollama - Explain --> Ollama - RootCause --> Ollama - Embed --> Ollama - - Classification --> API - Clustering --> API - Frequency --> API - Anomaly --> API - Summary --> API - Explain --> API - 
RootCause --> API - SemanticAnom --> API - - API --> Dashboard - - --- - Processing State Flow - - Each log record moves through a deterministic state machine from raw text to enriched, analysed data: - - stateDiagram-v2 - [*] --> Uploaded: CSV file received - - Uploaded --> Parsing: parse_csv() - Parsing --> Masked: mask_message() - Masked --> Structured: _parse_message() - - state Structured { - [*] --> FieldExtraction - FieldExtraction --> TypeClassified: log_type assigned - TypeClassified --> TimestampParsed: _parse_timestamp() - } - - Structured --> Stored: store_logs() - Stored --> SessionBuilt: build_sessions() - - state SessionBuilt { - [*] --> Grouped: group by tracking_id - Grouped --> Paired: match request ↔ response - Paired --> Enriched: compute duration, is_error - } - - SessionBuilt --> ReadyForAnalysis - - state ReadyForAnalysis { - [*] --> StatisticalAnalysis - [*] --> LLMAnalysis - - state StatisticalAnalysis { - [*] --> Classified: status + error type - [*] --> Clustered: error code grouping - [*] --> FrequencyAnalysed: time windows + spikes - [*] --> AnomalyDetected: z-score on errors, rate, volume - } - - state LLMAnalysis { - [*] --> Embedded: nomic-embed-text → 768-dim vectors - Embedded --> KMeansClustered: k=10, 20 iterations - KMeansClustered --> SemanticAnomaliesFound: rank by centroid distance - [*] --> RootCausesIdentified: error sessions → LLM grouping - AnomalyDetected --> AnomalyExplained: LLM narrative per window - Classified --> SummaryGenerated: LLM synthesis - } - } - - ReadyForAnalysis --> Served: REST API - Served --> Visualised: React Dashboard - Visualised --> [*] - - --- - Pipeline Stages in Detail - - Stage 1 — Ingestion - - The entry point accepts either a file upload (POST /api/v1/logs/upload) or a server-side path (POST - /api/v1/logs/ingest). The parser handles Kibana's specific CSV dialect: comma-separated with a timestamp column in - "Feb 9, 2026 @ 13:03:50.657" format and a freeform message body. 
- - Before any data is stored, the masking layer applies 10 regex patterns to redact authorization headers, bearer tokens, - cookies, session IDs, API keys, and passwords. The original text is never persisted — only the masked variant reaches - the database. - - Each message is then classified by type (request_incoming, response_outgoing, or unknown) and its structured fields — - URI, HTTP method, status code, payload presence — are extracted via key-value parsing of the message body. - - Stage 2 — Session Construction - - Raw log records are grouped by tracking_id to reconstruct request-response sessions. The system pairs each incoming - request with its outgoing response, computes round-trip duration in milliseconds, and flags sessions as errors when - the response status is ≥ 400. This produces the SessionRecord table that serves as the foundation for all downstream - analysis. - - Stage 3 — Statistical Analysis - - Four deterministic analysers run against the stored data: - - ┌─────────────────┬───────────────────────────────────────────────────────────┬───────────────────────────────────┐ - │ Analyser │ Method │ Output │ - ├─────────────────┼───────────────────────────────────────────────────────────┼───────────────────────────────────┤ - │ │ Maps status codes to categories; regex-matches error │ Error rate, status distribution, │ - │ Classification │ payloads against 9 known patterns (timeout, auth failure, │ error type distribution │ - │ │ upstream error, etc.) 
│ │ - ├─────────────────┼───────────────────────────────────────────────────────────┼───────────────────────────────────┤ - │ Error │ Extracts code field from JSON response payloads; groups │ Ranked clusters with counts and │ - │ Clustering │ by error code │ sample tracking IDs │ - ├─────────────────┼───────────────────────────────────────────────────────────┼───────────────────────────────────┤ - │ Frequency │ Buckets all responses into configurable time windows; │ Time series, spike list, top │ - │ Analysis │ detects spikes where error rate exceeds 2× the overall │ errors by status and code │ - │ │ average │ │ - ├─────────────────┼───────────────────────────────────────────────────────────┼───────────────────────────────────┤ - │ Anomaly │ Computes z-scores on three metrics (error count, error │ Anomaly list with z-scores, │ - │ Detection │ rate, request volume) per time window; flags windows │ severity (warning / critical), │ - │ │ exceeding the threshold │ and stats │ - └─────────────────┴───────────────────────────────────────────────────────────┴───────────────────────────────────┘ - - Stage 4 — LLM-Powered Intelligence - - This layer uses a locally-hosted Ollama instance to add reasoning capabilities that go beyond pattern matching. Every - LLM feature includes a rule-based fallback for environments where Ollama is unavailable. - - Semantic Embedding & Anomaly Detection — Log messages are embedded into 768-dimensional vectors via nomic-embed-text - through Ollama's /api/embed endpoint. A pure-Python k-means implementation (k=10, 20 iterations, seeded for - reproducibility) clusters these vectors. Logs with the greatest cosine distance to their assigned centroid are - surfaced as semantic anomalies — patterns the system has never seen before, regardless of HTTP status code. 
- - Root Cause Analysis — Error sessions are collected, serialised as JSON, and sent to the LLM with instructions to group - them into named root causes with severity ratings and explanations. The fallback path uses the existing error - clustering logic and assigns severity based on occurrence count thresholds (≥10 critical, ≥5 high, ≥2 medium). - - Anomaly Explanation — For any detected statistical anomaly, the system retrieves the logs in that time window and asks - the LLM to produce a 2–4 sentence narrative explaining the likely cause. The fallback constructs a factual summary - citing the top error status and affected endpoints. - - Summary Generation — Combines classification, clustering, and frequency data into a structured prompt. The LLM returns - a 3–5 sentence health assessment covering system status, critical issues, and recommended actions. - - Stage 5 — Presentation - - The FastAPI backend exposes 11 REST endpoints under /api/v1/. The React dashboard consumes these through a typed API - client and renders: - - - Overview — health metrics, status distribution pie charts, top error clusters - - Errors — error cluster table with sample tracking IDs - - Time Series — request volume bar chart, error rate line chart, anomaly markers with per-row "Explain" buttons that - expand to show LLM narratives - - Sessions — filterable session list with drill-down to individual request/response logs - - Semantic Anomalies — embed trigger button, anomaly table ranked by centroid distance - - Root Causes — severity-coded cards with explanations and sample tracking IDs - - --- - Infrastructure - - flowchart TB - subgraph Docker Compose - PG["PostgreSQL 16"] - OL["Ollama Server"] - OI["ollama-init\n(pulls nomic-embed-text\n+ phi3:mini)"] - BE["FastAPI Backend\n:8000"] - FE["React Dashboard\n:3000"] - DO["Docs Site\n:8081"] - end - - OI -->|depends_on| OL - BE -->|depends_on| PG - BE -->|depends_on| OL - FE -->|depends_on| BE - - BE -- "DATABASE_URL" --> PG - BE -- 
"OLLAMA_BASE_URL\nOLLAMA_MODEL\nOLLAMA_EMBED_MODEL" --> OL - - All models run on the Ollama instance — phi3:mini for text generation, nomic-embed-text for embeddings. The - ollama-init sidecar container pulls both models on first startup so the system is ready without manual intervention. diff --git a/docs/io-readers-guide.md b/docs/io-readers-guide.md index 918b8f53..d431a7c3 100644 --- a/docs/io-readers-guide.md +++ b/docs/io-readers-guide.md @@ -35,6 +35,38 @@ dependencies { ## GGUF Reader Usage +> **Recommended:** For large model files, use `StreamingGGUFReader` instead of `GGUFReader`. +> The streaming reader parses only metadata (~1 MB) and loads tensors on-demand, supporting +> files over 100 GB without heap-loading the entire file. It also supports quantized types +> (Q4_K, Q8_0, etc.) via `StreamingGgufParametersLoader`. See the streaming examples below. + +### Streaming GGUF Reading (Recommended) + +```kotlin +import sk.ainet.io.JvmRandomAccessSource +import sk.ainet.io.gguf.StreamingGGUFReader + +fun readLargeModel(filePath: String) { + val source = JvmRandomAccessSource.open(filePath) + StreamingGGUFReader.open(source).use { reader -> + println("Tensors: ${reader.tensorCount}") + println("Architecture: ${reader.fields["general.architecture"]}") + + // Load specific tensor on demand + val weights = reader.loadTensor("token_embd.weight") + + // Or get a TensorStorage descriptor with metadata + val storage = reader.loadTensorStorage("token_embd.weight") + println("Encoding: ${storage.encoding.name}, Physical: ${storage.physicalBytes} bytes") + } +} +``` + +### Legacy GGUF Reading + +> **Note:** The legacy `GGUFReader` loads the entire file into memory and only supports +> F32/I32 tensors. Prefer `StreamingGGUFReader` for new code. 
+ ### Basic GGUF Reading ```kotlin diff --git a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt index 8e4ff76a..5f9c7e90 100644 --- a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt +++ b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt @@ -769,6 +769,72 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory } } + override fun convTranspose1d( + input: Tensor, + weight: Tensor, + bias: Tensor?, + stride: Int, + padding: Int, + outputPadding: Int, + dilation: Int, + groups: Int + ): Tensor { + // input: [batch, inChannels, inLength] + // weight: [inChannels, outChannels/groups, kernelSize] + val batch = input.shape[0] + val inChannels = input.shape[1] + val inLength = input.shape[2] + val outChannelsPerGroup = weight.shape[1] + val kernelSize = weight.shape[2] + val outChannels = outChannelsPerGroup * groups + val outLength = (inLength - 1) * stride - 2 * padding + dilation * (kernelSize - 1) + outputPadding + 1 + + val outData = dataFactory.zeros(Shape(batch, outChannels, outLength), input.dtype) + + val inData = input.data + val wData = weight.data + + val inChPerGroup = inChannels / groups + + for (b in 0 until batch) { + for (g in 0 until groups) { + for (ic in 0 until inChPerGroup) { + for (oc in 0 until outChannelsPerGroup) { + for (il in 0 until inLength) { + val inputVal = inData.get(b, g * inChPerGroup + ic, il) as Float + if (inputVal == 0f) continue + for (k in 0 until kernelSize) { + val ol = il * stride - padding + k * dilation + if (ol < 0 || ol >= outLength) continue + val weightVal = wData.get(g * inChPerGroup + ic, oc, k) as Float + val existing = outData.get(b, g * outChannelsPerGroup + oc, ol) as Float + @Suppress("UNCHECKED_CAST") + outData.set(b, g * 
outChannelsPerGroup + oc, ol, value = (existing + inputVal * weightVal) as V) + } + } + } + } + } + } + + // Add bias + if (bias != null) { + val biasData = bias.data + for (b in 0 until batch) { + for (oc in 0 until outChannels) { + val biasVal = biasData.get(oc) as Float + for (ol in 0 until outLength) { + val existing = outData.get(b, oc, ol) as Float + @Suppress("UNCHECKED_CAST") + outData.set(b, oc, ol, value = (existing + biasVal) as V) + } + } + } + } + + return newTensor(outData, input.dtype, input) + } + @TensorOp() override fun conv3d( input: Tensor, @@ -2262,6 +2328,33 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory return newTensor(outData, tensor.dtype, tensor) } + override fun sin(tensor: Tensor): Tensor { + val outData = dataFactory.init(tensor.shape, tensor.dtype) { idx -> + val x = tensor.data.get(*idx) as Float + @Suppress("UNCHECKED_CAST") + kotlin.math.sin(x) as V + } + return newTensor(outData, tensor.dtype, tensor) + } + + override fun cos(tensor: Tensor): Tensor { + val outData = dataFactory.init(tensor.shape, tensor.dtype) { idx -> + val x = tensor.data.get(*idx) as Float + @Suppress("UNCHECKED_CAST") + kotlin.math.cos(x) as V + } + return newTensor(outData, tensor.dtype, tensor) + } + + override fun tanh(tensor: Tensor): Tensor { + val outData = dataFactory.init(tensor.shape, tensor.dtype) { idx -> + val x = tensor.data.get(*idx) as Float + @Suppress("UNCHECKED_CAST") + kotlin.math.tanh(x).toFloat() as V + } + return newTensor(outData, tensor.dtype, tensor) + } + override fun scaledDotProductAttention( query: Tensor, key: Tensor, diff --git a/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/JvmTurboQuantKernels.kt b/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/JvmTurboQuantKernels.kt new file mode 100644 index 00000000..caf80b1e --- /dev/null +++ 
b/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/JvmTurboQuantKernels.kt
@@ -0,0 +1,203 @@
+package sk.ainet.exec.tensor.ops
+
+import jdk.incubator.vector.FloatVector
+import jdk.incubator.vector.VectorOperators
+import jdk.incubator.vector.VectorSpecies
+import sk.ainet.lang.tensor.ops.turboquant.BitPacker
+import sk.ainet.lang.tensor.ops.turboquant.QuantizedVector
+import sk.ainet.lang.tensor.ops.turboquant.ScalarQuantizer
+import kotlin.math.abs
+import kotlin.math.max
+import kotlin.math.min
+import kotlin.math.round
+
+/**
+ * JVM SIMD-optimized kernels for TurboQuant operations.
+ *
+ * Uses the Java Vector API (jdk.incubator.vector) for CPU SIMD acceleration
+ * of TurboQuant encode/decode paths. Falls back to scalar code for
+ * non-aligned tails.
+ *
+ * These kernels optimize the hot paths:
+ * - Per-group abs-max computation (for scale calculation)
+ * - Vectorized quantization (float → code)
+ * - Vectorized dequantization (code → float)
+ * - Walsh-Hadamard transform butterfly stages
+ *
+ * Usage: Called by the CPU backend when TurboQuant-encoded K/V is detected
+ * in the attention path.
+ */
+public object JvmTurboQuantKernels {
+
+    private val FLOAT_SPECIES: VectorSpecies<Float> = FloatVector.SPECIES_PREFERRED
+    private val floatStep: Int = FLOAT_SPECIES.length()
+
+    // ========== Vectorized abs-max (for scale computation) ==========
+
+    /**
+     * Find the maximum absolute value in a float array segment.
+     * SIMD-accelerated with scalar tail.
+ */ + public fun absMax(data: FloatArray, offset: Int, length: Int): Float { + var maxVec = FloatVector.zero(FLOAT_SPECIES) + val end = offset + length + val loopBound = FLOAT_SPECIES.loopBound(length) + offset + var i = offset + + // Vectorized loop + while (i < loopBound) { + val v = FloatVector.fromArray(FLOAT_SPECIES, data, i) + maxVec = maxVec.max(v.abs()) + i += floatStep + } + + // Reduce vector to scalar + var result = maxVec.reduceLanes(VectorOperators.MAX) + + // Scalar tail + while (i < end) { + result = max(result, abs(data[i])) + i++ + } + return result + } + + // ========== Vectorized quantization ========== + + /** + * SIMD-accelerated scalar quantization with per-group scales. + * + * Replaces [ScalarQuantizer.quantize] for the hot path. + */ + public fun quantize(input: FloatArray, bits: Int): QuantizedVector { + val maxCode = (1 shl (bits - 1)) - 1 + val groupSize = ScalarQuantizer.GROUP_SIZE + val numGroups = (input.size + groupSize - 1) / groupSize + val scales = FloatArray(numGroups) + val codes = ByteArray(input.size) + + for (g in 0 until numGroups) { + val start = g * groupSize + val end = min(start + groupSize, input.size) + val groupLen = end - start + + // SIMD abs-max + val absMax = absMax(input, start, groupLen) + val scale = if (absMax > 0f) absMax / maxCode else 0f + scales[g] = scale + + if (scale > 0f) { + val invScale = 1f / scale + val invScaleVec = FloatVector.broadcast(FLOAT_SPECIES, invScale) + val maxCodeF = maxCode.toFloat() + val minCodeF = -maxCode.toFloat() + val maxVec = FloatVector.broadcast(FLOAT_SPECIES, maxCodeF) + val minVec = FloatVector.broadcast(FLOAT_SPECIES, minCodeF) + + val loopBound = FLOAT_SPECIES.loopBound(groupLen) + start + var i = start + + // Vectorized quantize + while (i < loopBound) { + val v = FloatVector.fromArray(FLOAT_SPECIES, input, i) + val scaled = v.mul(invScaleVec) + // Clamp to [-maxCode, maxCode] + val clamped = scaled.min(maxVec).max(minVec) + // Convert to int codes (round) + for (j in 
0 until floatStep) { + codes[i + j] = round(clamped.lane(j)).toInt().toByte() + } + i += floatStep + } + + // Scalar tail + while (i < end) { + val q = round(input[i] * invScale).toInt() + codes[i] = q.coerceIn(-maxCode, maxCode).toByte() + i++ + } + } + } + + return QuantizedVector(codes, scales, bits) + } + + // ========== Vectorized dequantization ========== + + /** + * SIMD-accelerated dequantization. + * + * Replaces [ScalarQuantizer.dequantize] for the hot path. + */ + public fun dequantize(codes: ByteArray, scales: FloatArray, output: FloatArray, offset: Int = 0) { + val groupSize = ScalarQuantizer.GROUP_SIZE + + for (g in scales.indices) { + val start = g * groupSize + val end = min(start + groupSize, codes.size) + val groupLen = end - start + val scale = scales[g] + val scaleVec = FloatVector.broadcast(FLOAT_SPECIES, scale) + + val loopBound = FLOAT_SPECIES.loopBound(groupLen) + start + var i = start + + // Vectorized dequant: output = code * scale + while (i < loopBound) { + // Load codes as floats + val floats = FloatArray(floatStep) + for (j in 0 until floatStep) { + floats[j] = codes[i + j].toFloat() + } + val codeVec = FloatVector.fromArray(FLOAT_SPECIES, floats, 0) + val result = codeVec.mul(scaleVec) + result.intoArray(output, offset + i) + i += floatStep + } + + // Scalar tail + while (i < end) { + output[offset + i] = codes[i].toFloat() * scale + i++ + } + } + } + + // ========== Vectorized Walsh-Hadamard butterfly ========== + + /** + * SIMD-accelerated Walsh-Hadamard transform butterfly stage. + * + * Each butterfly stage computes: (a, b) → (a+b, a-b) for pairs + * separated by stride `h`. The SIMD version processes multiple + * pairs simultaneously. 
+ */ + public fun walshHadamardButterfly(data: FloatArray, h: Int, len: Int) { + var i = 0 + while (i < len) { + var j = i + val jEnd = i + h + val loopBound = FLOAT_SPECIES.loopBound(h) + i + + // Vectorized butterfly + while (j < loopBound) { + val a = FloatVector.fromArray(FLOAT_SPECIES, data, j) + val b = FloatVector.fromArray(FLOAT_SPECIES, data, j + h) + a.add(b).intoArray(data, j) + a.sub(b).intoArray(data, j + h) + j += floatStep + } + + // Scalar tail + while (j < jEnd) { + val x = data[j] + val y = data[j + h] + data[j] = x + y + data[j + h] = x - y + j++ + } + + i += h * 2 + } + } +} diff --git a/skainet-io/skainet-io-core/src/commonMain/kotlin/sk/ainet/io/FallbackMappedMemoryChunk.kt b/skainet-io/skainet-io-core/src/commonMain/kotlin/sk/ainet/io/FallbackMappedMemoryChunk.kt new file mode 100644 index 00000000..4b930511 --- /dev/null +++ b/skainet-io/skainet-io-core/src/commonMain/kotlin/sk/ainet/io/FallbackMappedMemoryChunk.kt @@ -0,0 +1,65 @@ +package sk.ainet.io + +/** + * Fallback [MappedMemoryChunk] implementation backed by a heap [ByteArray]. + * + * Used on platforms without native mmap support (JS, Wasm). The data is + * eagerly loaded into memory, so this does not provide the OS-paged + * benefits of a true memory-mapped file. It does, however, satisfy the + * [MappedMemoryChunk] contract so that code written against that interface + * works on all Kotlin Multiplatform targets. 
+ */ +public class FallbackMappedMemoryChunk( + override val path: String, + override val fileOffset: Long, + private val data: ByteArray, + private val dataOffset: Int = 0, + override val size: Long = (data.size - dataOffset).toLong() +) : MappedMemoryChunk { + + override fun readByte(offset: Long): Byte { + require(offset in 0 until size) { "Offset out of bounds: $offset (size=$size)" } + return data[dataOffset + offset.toInt()] + } + + override fun readBytes(offset: Long, length: Int): ByteArray { + require(offset >= 0 && offset + length <= size) { + "Range out of bounds: offset=$offset length=$length size=$size" + } + return data.copyOfRange(dataOffset + offset.toInt(), dataOffset + offset.toInt() + length) + } + + override fun slice(offset: Long, length: Long): MemoryChunk { + require(offset >= 0 && offset + length <= size) { + "Slice out of bounds: offset=$offset length=$length size=$size" + } + return FallbackMappedMemoryChunk( + path = path, + fileOffset = fileOffset + offset, + data = data, + dataOffset = dataOffset + offset.toInt(), + size = length + ) + } + + override fun close() { + // No-op: heap memory is GC'd + } + + public companion object { + /** + * Create a fallback chunk by reading from a [RandomAccessSource]. + * This eagerly loads the region into heap — use JvmMappedMemoryChunk + * on JVM for true mmap. 
+ */ + public fun fromSource( + source: RandomAccessSource, + path: String, + offset: Long = 0, + length: Long = source.size - offset + ): FallbackMappedMemoryChunk { + val data = source.readAt(offset, length.toInt()) + return FallbackMappedMemoryChunk(path, offset, data) + } + } +} diff --git a/skainet-io/skainet-io-core/src/commonMain/kotlin/sk/ainet/io/MappedMemoryChunk.kt b/skainet-io/skainet-io-core/src/commonMain/kotlin/sk/ainet/io/MappedMemoryChunk.kt new file mode 100644 index 00000000..4356b6dc --- /dev/null +++ b/skainet-io/skainet-io-core/src/commonMain/kotlin/sk/ainet/io/MappedMemoryChunk.kt @@ -0,0 +1,20 @@ +package sk.ainet.io + +/** + * A [MemoryChunk] backed by a memory-mapped file region. + * + * On platforms that support mmap (JVM, native), this avoids loading the + * entire region into heap memory — the OS pages data in on demand. On + * platforms without mmap support (JS, Wasm), the factory falls back to + * reading the region into a [ByteArrayMemoryChunk]. + * + * Instances are immutable from the runtime's perspective. + */ +public interface MappedMemoryChunk : MemoryChunk, AutoCloseable { + + /** The file path this chunk is mapped from. */ + public val path: String + + /** The byte offset within the file where the mapping starts. */ + public val fileOffset: Long +} diff --git a/skainet-io/skainet-io-core/src/commonMain/kotlin/sk/ainet/io/tokenizer/TekkenTokenizer.kt b/skainet-io/skainet-io-core/src/commonMain/kotlin/sk/ainet/io/tokenizer/TekkenTokenizer.kt new file mode 100644 index 00000000..64212e5b --- /dev/null +++ b/skainet-io/skainet-io-core/src/commonMain/kotlin/sk/ainet/io/tokenizer/TekkenTokenizer.kt @@ -0,0 +1,472 @@ +package sk.ainet.io.tokenizer + +import kotlin.io.encoding.Base64 +import kotlin.io.encoding.ExperimentalEncodingApi + +/** + * Mistral Tekken tokenizer implementation. + * + * Tekken is a tiktoken-based BPE tokenizer used by Mistral models (Mistral, + * Mixtral, Codestral, Voxtral, etc.). 
Unlike HuggingFace tokenizer.json,
+ * tekken.json uses:
+ * - Base64-encoded byte sequences for vocab tokens
+ * - Implicit merge ordering from vocab rank (lower rank = higher priority)
+ * - Separate special token list with reserved ID space at [0, numSpecialTokens)
+ * - tiktoken-style pre-tokenization regex pattern
+ *
+ * Token ID layout:
+ * ```
+ * IDs [0, numSpecialTokens)   → special tokens (<s>, </s>, <unk>, [INST], ...)
+ * IDs [numSpecialTokens, ...] → vocab tokens (rank 0..N offset by numSpecialTokens)
+ * ```
+ *
+ * @param vocabTokenBytes List of byte arrays, indexed by rank (ranks 0-255 are the 256 single bytes)
+ * @param vocabTokenStrings List of optional string representations, indexed by rank
+ * @param specialTokens Map of special token string → token ID
+ * @param specialTokensById Map of token ID → special token string (for decoding)
+ * @param numSpecialTokens Number of reserved special token IDs (default: 1000)
+ * @param pattern Pre-tokenization regex pattern (tiktoken-style)
+ */
+public class TekkenTokenizer(
+    private val vocabTokenBytes: List<ByteArray>,
+    private val vocabTokenStrings: List<String?>,
+    private val specialTokens: Map<String, Int>,
+    private val specialTokensById: Map<Int, String>,
+    private val numSpecialTokens: Int = 1000,
+    private val pattern: Regex
+) {
+    /** BPE rank lookup: byte sequence → rank (merge priority). */
+    private val bytesToRank: HashMap<ByteArrayKey, Int> = HashMap(vocabTokenBytes.size * 2)
+
+    init {
+        for (i in vocabTokenBytes.indices) {
+            bytesToRank[ByteArrayKey(vocabTokenBytes[i])] = i
+        }
+    }
+
+    /** Number of vocab tokens (excluding special tokens). */
+    public val vocabSize: Int get() = vocabTokenBytes.size
+
+    /** Total token count (vocab + special tokens). */
+    public val totalTokens: Int get() = vocabTokenBytes.size + numSpecialTokens
+
+    /** BOS token ID. */
+    public val bosTokenId: Int get() = specialTokens["<s>"] ?: 1
+
+    /** EOS token ID. */
+    public val eosTokenId: Int get() = specialTokens["</s>"] ?: 2
+
+    /**
+     * Encode text to token IDs.
+     *
+     * 1. 
Split text using pre-tokenization regex pattern
+     * 2. For each chunk, convert to bytes and apply BPE merges
+     * 3. Offset ranks by numSpecialTokens to get final IDs
+     */
+    public fun encode(text: String): IntArray {
+        val tokens = mutableListOf<Int>()
+
+        // Check for special tokens in the text first
+        var remaining = text
+        while (remaining.isNotEmpty()) {
+            // Try to match a special token at current position
+            var matchedSpecial = false
+            for ((token, id) in specialTokens) {
+                if (remaining.startsWith(token)) {
+                    tokens.add(id)
+                    remaining = remaining.substring(token.length)
+                    matchedSpecial = true
+                    break
+                }
+            }
+            if (matchedSpecial) continue
+
+            // Find the next special token position (or end of string)
+            var nextSpecialPos = remaining.length
+            for (token in specialTokens.keys) {
+                val pos = remaining.indexOf(token)
+                if (pos in 1 until nextSpecialPos) {
+                    nextSpecialPos = pos
+                }
+            }
+
+            // Encode the non-special segment
+            val segment = remaining.substring(0, nextSpecialPos)
+            remaining = remaining.substring(nextSpecialPos)
+
+            // Pre-tokenize with regex pattern
+            val matches = pattern.findAll(segment)
+            for (match in matches) {
+                val chunk = match.value
+                val chunkBytes = chunk.encodeToByteArray()
+                val merged = bpeMerge(chunkBytes)
+                for (rank in merged) {
+                    tokens.add(rank + numSpecialTokens)
+                }
+            }
+        }
+
+        return tokens.toIntArray()
+    }
+
+    /**
+     * Decode token IDs to text. 
+     */
+    public fun decode(tokens: IntArray): String {
+        val bytes = mutableListOf<Byte>()
+        val result = StringBuilder()
+
+        for (id in tokens) {
+            if (id < numSpecialTokens) {
+                // Flush accumulated bytes
+                if (bytes.isNotEmpty()) {
+                    result.append(bytes.toByteArray().decodeToString())
+                    bytes.clear()
+                }
+                result.append(specialTokensById[id] ?: "")
+            } else {
+                val rank = id - numSpecialTokens
+                if (rank in vocabTokenBytes.indices) {
+                    bytes.addAll(vocabTokenBytes[rank].toList())
+                }
+            }
+        }
+
+        // Flush remaining bytes
+        if (bytes.isNotEmpty()) {
+            result.append(bytes.toByteArray().decodeToString())
+        }
+
+        return result.toString()
+    }
+
+    /**
+     * Decode a single token ID to text.
+     */
+    public fun decode(token: Int): String {
+        if (token < numSpecialTokens) {
+            return specialTokensById[token] ?: ""
+        }
+        val rank = token - numSpecialTokens
+        if (rank in vocabTokenBytes.indices) {
+            return vocabTokenBytes[rank].decodeToString()
+        }
+        return ""
+    }
+
+    /**
+     * Apply BPE merges to a byte sequence.
+     *
+     * tiktoken BPE: repeatedly find the pair of adjacent tokens with the
+     * lowest rank and merge them, until no more merges are possible. 
+     *
+     * @param bytes Input byte sequence
+     * @return List of vocab ranks (NOT token IDs — caller adds numSpecialTokens offset)
+     */
+    private fun bpeMerge(bytes: ByteArray): List<Int> {
+        if (bytes.isEmpty()) return emptyList()
+
+        // Start with single-byte tokens (ranks 0-255)
+        val pieces = ArrayList<ByteArray>(bytes.size)
+        for (b in bytes) {
+            pieces.add(byteArrayOf(b))
+        }
+
+        while (pieces.size > 1) {
+            // Find the pair with lowest merge rank
+            var bestRank = Int.MAX_VALUE
+            var bestIdx = -1
+
+            for (i in 0 until pieces.size - 1) {
+                val merged = concat(pieces[i], pieces[i + 1])
+                val rank = bytesToRank[ByteArrayKey(merged)]
+                if (rank != null && rank < bestRank) {
+                    bestRank = rank
+                    bestIdx = i
+                }
+            }
+
+            if (bestIdx == -1) break // no more merges possible
+
+            // Apply the merge
+            val merged = concat(pieces[bestIdx], pieces[bestIdx + 1])
+            pieces[bestIdx] = merged
+            pieces.removeAt(bestIdx + 1)
+        }
+
+        // Convert byte sequences to ranks
+        return pieces.map { piece ->
+            bytesToRank[ByteArrayKey(piece)]
+                ?: error("BPE produced unknown byte sequence: ${piece.toList()}")
+        }
+    }
+
+    private fun concat(a: ByteArray, b: ByteArray): ByteArray {
+        val result = ByteArray(a.size + b.size)
+        a.copyInto(result)
+        b.copyInto(result, a.size)
+        return result
+    }
+
+    public companion object {
+        /**
+         * Parse a tekken.json string into a [TekkenTokenizer].
+         *
+         * Uses lightweight JSON parsing to avoid kotlinx.serialization dependency
+         * in the tokenizer itself (the JSON structure is simple enough).
+         */
+        @OptIn(ExperimentalEncodingApi::class)
+        public fun fromJson(json: String): TekkenTokenizer {
+            val parser = TekkenJsonParser(json)
+            return parser.parse()
+        }
+    }
+}
+
+/**
+ * Wrapper for ByteArray that implements equals/hashCode for use as HashMap key. 
+ */ +internal class ByteArrayKey(val bytes: ByteArray) { + override fun equals(other: Any?): Boolean { + if (other !is ByteArrayKey) return false + return bytes.contentEquals(other.bytes) + } + + override fun hashCode(): Int = bytes.contentHashCode() +} + +/** + * Lightweight parser for tekken.json format. + */ +@OptIn(ExperimentalEncodingApi::class) +internal class TekkenJsonParser(private val json: String) { + + fun parse(): TekkenTokenizer { + // Extract config + val numSpecialTokens = extractInt("default_num_special_tokens") ?: 1000 + val patternStr = extractString("pattern") + ?: "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|[\\p{N}]+| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s+(?!\\S)|\\s+" + + // Parse vocab array + val vocabTokenBytes = mutableListOf() + val vocabTokenStrings = mutableListOf() + parseVocabEntries(vocabTokenBytes, vocabTokenStrings) + + // Parse special tokens + val specialTokens = mutableMapOf() + val specialTokensById = mutableMapOf() + parseSpecialTokens(specialTokens, specialTokensById) + + // Compile regex pattern + val pattern = try { + Regex(patternStr) + } catch (e: Exception) { + // Fallback: split on whitespace boundaries + Regex("\\S+|\\s+") + } + + return TekkenTokenizer( + vocabTokenBytes = vocabTokenBytes, + vocabTokenStrings = vocabTokenStrings, + specialTokens = specialTokens, + specialTokensById = specialTokensById, + numSpecialTokens = numSpecialTokens, + pattern = pattern + ) + } + + private fun parseVocabEntries( + tokenBytes: MutableList, + tokenStrings: MutableList + ) { + // Find "vocab" array + val vocabStart = json.indexOf("\"vocab\"") + if (vocabStart < 0) return + + val arrayStart = json.indexOf('[', vocabStart) + if (arrayStart < 0) return + + // Parse each entry: {"rank": N, "token_bytes": "...", "token_str": "..."} + var pos = arrayStart + 1 + while (pos < json.length) { + pos = skipWhitespace(pos) + if (pos 
>= json.length || json[pos] == ']') break + + if (json[pos] == '{') { + val objEnd = findMatchingBrace(pos) + val obj = json.substring(pos, objEnd + 1) + + val tokenBytesB64 = extractStringFromObj(obj, "token_bytes") + val tokenStr = extractStringFromObj(obj, "token_str") + + if (tokenBytesB64 != null) { + val decoded = Base64.decode(tokenBytesB64) + tokenBytes.add(decoded) + tokenStrings.add(tokenStr) + } + + pos = objEnd + 1 + } else { + pos++ + } + + pos = skipWhitespace(pos) + if (pos < json.length && json[pos] == ',') pos++ + } + } + + private fun parseSpecialTokens( + specialTokens: MutableMap, + specialTokensById: MutableMap + ) { + val stStart = json.indexOf("\"special_tokens\"") + if (stStart < 0) return + + val arrayStart = json.indexOf('[', stStart) + if (arrayStart < 0) return + + var pos = arrayStart + 1 + while (pos < json.length) { + pos = skipWhitespace(pos) + if (pos >= json.length || json[pos] == ']') break + + if (json[pos] == '{') { + val objEnd = findMatchingBrace(pos) + val obj = json.substring(pos, objEnd + 1) + + val rank = extractIntFromObj(obj, "rank") + val tokenStr = extractStringFromObj(obj, "token_str") + + if (rank != null && tokenStr != null) { + specialTokens[tokenStr] = rank + specialTokensById[rank] = tokenStr + } + + pos = objEnd + 1 + } else { + pos++ + } + + pos = skipWhitespace(pos) + if (pos < json.length && json[pos] == ',') pos++ + } + } + + // ========== JSON helpers ========== + + private fun extractInt(key: String): Int? { + val keyStr = "\"$key\"" + val idx = json.indexOf(keyStr) + if (idx < 0) return null + var pos = idx + keyStr.length + pos = skipWhitespace(pos) + if (pos < json.length && json[pos] == ':') pos++ + pos = skipWhitespace(pos) + val start = pos + while (pos < json.length && (json[pos].isDigit() || json[pos] == '-')) pos++ + return json.substring(start, pos).toIntOrNull() + } + + private fun extractString(key: String): String? 
{ + val keyStr = "\"$key\"" + val idx = json.indexOf(keyStr) + if (idx < 0) return null + var pos = idx + keyStr.length + pos = skipWhitespace(pos) + if (pos < json.length && json[pos] == ':') pos++ + pos = skipWhitespace(pos) + if (pos >= json.length || json[pos] != '"') return null + return readJsonString(pos) + } + + private fun extractStringFromObj(obj: String, key: String): String? { + val keyStr = "\"$key\"" + val idx = obj.indexOf(keyStr) + if (idx < 0) return null + var pos = idx + keyStr.length + while (pos < obj.length && (obj[pos] == ' ' || obj[pos] == ':')) pos++ + if (pos >= obj.length) return null + if (obj[pos] == 'n' && obj.startsWith("null", pos)) return null + if (obj[pos] != '"') return null + return readJsonStringFrom(obj, pos) + } + + private fun extractIntFromObj(obj: String, key: String): Int? { + val keyStr = "\"$key\"" + val idx = obj.indexOf(keyStr) + if (idx < 0) return null + var pos = idx + keyStr.length + while (pos < obj.length && (obj[pos] == ' ' || obj[pos] == ':')) pos++ + val start = pos + while (pos < obj.length && (obj[pos].isDigit() || obj[pos] == '-')) pos++ + return obj.substring(start, pos).toIntOrNull() + } + + private fun readJsonString(startPos: Int): String { + return readJsonStringFrom(json, startPos) + } + + private fun readJsonStringFrom(s: String, startPos: Int): String { + val sb = StringBuilder() + var pos = startPos + 1 // skip opening quote + while (pos < s.length) { + val c = s[pos] + when { + c == '"' -> return sb.toString() + c == '\\' && pos + 1 < s.length -> { + pos++ + when (s[pos]) { + '"' -> sb.append('"') + '\\' -> sb.append('\\') + '/' -> sb.append('/') + 'n' -> sb.append('\n') + 'r' -> sb.append('\r') + 't' -> sb.append('\t') + 'b' -> sb.append('\b') + 'f' -> sb.append('\u000C') + 'u' -> { + if (pos + 4 < s.length) { + val hex = s.substring(pos + 1, pos + 5) + val cp = hex.toIntOrNull(16) ?: 0 + sb.append(cp.toChar()) + pos += 4 + } + } + } + } + else -> sb.append(c) + } + pos++ + } + return 
sb.toString() + } + + private fun skipWhitespace(pos: Int): Int { + var p = pos + while (p < json.length && json[p].isWhitespace()) p++ + return p + } + + private fun findMatchingBrace(start: Int): Int { + var depth = 0 + var inString = false + var pos = start + while (pos < json.length) { + val c = json[pos] + when { + inString -> { + if (c == '"') inString = false + else if (c == '\\') pos++ + } + c == '"' -> inString = true + c == '{' -> depth++ + c == '}' -> { + depth-- + if (depth == 0) return pos + } + } + pos++ + } + return json.length - 1 + } +} diff --git a/skainet-io/skainet-io-core/src/commonTest/kotlin/sk/ainet/io/FallbackMappedMemoryChunkTest.kt b/skainet-io/skainet-io-core/src/commonTest/kotlin/sk/ainet/io/FallbackMappedMemoryChunkTest.kt new file mode 100644 index 00000000..bf4281d2 --- /dev/null +++ b/skainet-io/skainet-io-core/src/commonTest/kotlin/sk/ainet/io/FallbackMappedMemoryChunkTest.kt @@ -0,0 +1,92 @@ +package sk.ainet.io + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +class FallbackMappedMemoryChunkTest { + + private fun chunk(data: ByteArray = ByteArray(100) { it.toByte() }) = + FallbackMappedMemoryChunk( + path = "/test/file.bin", + fileOffset = 0, + data = data + ) + + @Test + fun readByte_validOffset_returnsCorrectByte() { + val c = chunk() + assertEquals(0.toByte(), c.readByte(0)) + assertEquals(99.toByte(), c.readByte(99)) + } + + @Test + fun readByte_outOfBounds_throws() { + val c = chunk() + assertFailsWith { c.readByte(-1) } + assertFailsWith { c.readByte(100) } + } + + @Test + fun readBytes_range_returnsCorrectSubarray() { + val c = chunk() + val bytes = c.readBytes(10, 3) + assertEquals(3, bytes.size) + assertEquals(10.toByte(), bytes[0]) + assertEquals(12.toByte(), bytes[2]) + } + + @Test + fun readBytes_outOfBounds_throws() { + val c = chunk() + assertFailsWith { c.readBytes(98, 5) } // 98+5 > 100 + } + + @Test + fun slice_returnsSubChunk() { + val c = chunk() + val s = 
c.slice(50, 20) + assertEquals(20L, s.size) + assertEquals(50.toByte(), s.readByte(0)) + assertEquals(69.toByte(), s.readByte(19)) + } + + @Test + fun slice_ofSlice_composesOffsets() { + val c = chunk() + val s1 = c.slice(10, 50) as FallbackMappedMemoryChunk + val s2 = s1.slice(5, 10) + assertEquals(10L, s2.size) + // Should read from original data at offset 10+5=15 + assertEquals(15.toByte(), s2.readByte(0)) + } + + @Test + fun slice_outOfBounds_throws() { + val c = chunk() + assertFailsWith { c.slice(90, 20) } // 90+20 > 100 + } + + @Test + fun constructorWithDataOffset_readsFromOffset() { + val data = ByteArray(50) { (it + 10).toByte() } + val c = FallbackMappedMemoryChunk("/f.bin", 0, data, dataOffset = 10, size = 20) + assertEquals(20L, c.size) + assertEquals(20.toByte(), c.readByte(0)) // data[10] = 10+10 = 20 + } + + @Test + fun pathAndFileOffset_arePreserved() { + val c = FallbackMappedMemoryChunk("/model/weights.bin", fileOffset = 4096, data = ByteArray(10)) + assertEquals("/model/weights.bin", c.path) + assertEquals(4096L, c.fileOffset) + } + + @Test + fun close_isNoOp() { + val c = chunk() + c.close() // should not throw + // Can still read after close (heap-backed, no real resource to release) + assertEquals(0.toByte(), c.readByte(0)) + } +} diff --git a/skainet-io/skainet-io-core/src/jvmMain/kotlin/sk/ainet/io/JvmFileBackedResolver.kt b/skainet-io/skainet-io-core/src/jvmMain/kotlin/sk/ainet/io/JvmFileBackedResolver.kt new file mode 100644 index 00000000..cf5ea333 --- /dev/null +++ b/skainet-io/skainet-io-core/src/jvmMain/kotlin/sk/ainet/io/JvmFileBackedResolver.kt @@ -0,0 +1,62 @@ +package sk.ainet.io + +import sk.ainet.lang.tensor.storage.BufferAccessor +import sk.ainet.lang.tensor.storage.BufferHandle +import sk.ainet.lang.tensor.storage.DefaultBufferResolver +import java.io.File + +/** + * JVM file-backed buffer resolver using memory-mapped I/O. 
+ *
+ * Resolves [BufferHandle.FileBacked] handles by mapping the referenced
+ * file region via [JvmMappedMemoryChunk]. The OS manages page-in/out,
+ * so arbitrarily large weight tensors can be accessed without heap pressure.
+ *
+ * Usage:
+ * ```kotlin
+ * val resolver = JvmFileBackedResolver.createResolver()
+ * val accessor = resolver.resolve(fileBackedHandle)
+ * val bytes = accessor.readBytes(0, 100)
+ * accessor.close()
+ * ```
+ */
+public object JvmFileBackedResolver {
+
+    /**
+     * Create a [DefaultBufferResolver] whose file-backed branch delegates
+     * to [resolveFileBacked], i.e. to a JVM memory mapping.
+     */
+    public fun createResolver(): DefaultBufferResolver =
+        DefaultBufferResolver(fileBackedResolver = ::resolveFileBacked)
+
+    /**
+     * Resolve a single file-backed handle to an accessor backed by a
+     * read-only memory mapping of the referenced file region.
+     */
+    public fun resolveFileBacked(handle: BufferHandle.FileBacked): BufferAccessor =
+        MappedChunkAccessor(
+            JvmMappedMemoryChunk.open(
+                File(handle.path),
+                offset = handle.fileOffset,
+                length = handle.sizeInBytes
+            )
+        )
+}
+
+/**
+ * [BufferAccessor] backed by a [JvmMappedMemoryChunk].
+ * Closing this accessor closes the underlying memory mapping.
+ */
+internal class MappedChunkAccessor(
+    private val chunk: JvmMappedMemoryChunk
+) : BufferAccessor {
+
+    override val sizeInBytes: Long get() = chunk.size
+
+    override fun readByte(offset: Long): Byte = chunk.readByte(offset)
+
+    override fun readBytes(offset: Long, length: Int): ByteArray = chunk.readBytes(offset, length)
+
+    override fun close() {
+        chunk.close()
+    }
+}
diff --git a/skainet-io/skainet-io-core/src/jvmMain/kotlin/sk/ainet/io/JvmMappedMemoryChunk.kt b/skainet-io/skainet-io-core/src/jvmMain/kotlin/sk/ainet/io/JvmMappedMemoryChunk.kt
new file mode 100644
index 00000000..00a47060
--- /dev/null
+++ b/skainet-io/skainet-io-core/src/jvmMain/kotlin/sk/ainet/io/JvmMappedMemoryChunk.kt
@@ -0,0 +1,110 @@
+package sk.ainet.io
+
+import java.io.File
+import java.io.RandomAccessFile
+import java.nio.MappedByteBuffer
+import java.nio.channels.FileChannel
+
+/**
+ * JVM implementation of [MappedMemoryChunk] using [FileChannel.map].
+ *
+ * The mapped region is read-only and backed by the OS virtual memory
+ * subsystem. Pages are loaded on demand and evicted under memory pressure,
+ * so arbitrarily large regions can be mapped without consuming heap.
+ */
+public class JvmMappedMemoryChunk private constructor(
+    override val path: String,
+    override val fileOffset: Long,
+    override val size: Long,
+    private val buffer: MappedByteBuffer,
+    private val raf: RandomAccessFile
+) : MappedMemoryChunk {
+
+    override fun readByte(offset: Long): Byte {
+        require(offset in 0 until size) { "Offset out of bounds: $offset (size=$size)" }
+        return buffer.get(offset.toInt())
+    }
+
+    override fun readBytes(offset: Long, length: Int): ByteArray {
+        require(offset >= 0 && offset + length <= size) {
+            "Range out of bounds: offset=$offset length=$length size=$size"
+        }
+        val result = ByteArray(length)
+        // MappedByteBuffer is not thread-safe for positional reads,
+        // so we use a duplicate to avoid contention on position state.
+        val dup = buffer.duplicate()
+        dup.position(offset.toInt())
+        dup.get(result, 0, length)
+        return result
+    }
+
+    override fun slice(offset: Long, length: Long): MemoryChunk {
+        require(offset >= 0 && offset + length <= size) {
+            "Slice out of bounds: offset=$offset length=$length size=$size"
+        }
+        val dup = buffer.duplicate()
+        dup.position(offset.toInt())
+        dup.limit((offset + length).toInt())
+        val slicedBuffer = dup.slice() as MappedByteBuffer
+        return JvmMappedMemoryChunk(path, fileOffset + offset, length, slicedBuffer, raf)
+    }
+
+    override fun close() {
+        // MappedByteBuffer is unmapped when GC'd; we close the underlying file.
+        // NOTE: slices share this RandomAccessFile, so closing any slice closes it for all.
+        raf.close()
+    }
+
+    public companion object {
+
+        /**
+         * Map a region of a file into memory.
+         *
+         * @param file The file to map
+         * @param offset Byte offset within the file (must be non-negative)
+         * @param length Number of bytes to map (0 = map to end of file)
+         */
+        public fun open(file: File, offset: Long = 0, length: Long = 0): JvmMappedMemoryChunk {
+            require(file.exists()) { "File not found: ${file.absolutePath}" }
+            require(file.isFile) { "Not a file: ${file.absolutePath}" }
+            require(offset >= 0) { "Offset must be non-negative: $offset" }
+
+            val fileLength = file.length()
+            val actualLength = if (length == 0L) fileLength - offset else length
+            require(actualLength >= 0) { "Offset past end of file: offset=$offset fileLength=$fileLength" }
+            require(offset + actualLength <= fileLength) {
+                "Mapped region exceeds file: offset=$offset length=$actualLength file=$fileLength"
+            }
+            // FileChannel.map cannot map more than Int.MAX_VALUE bytes at once.
+            require(actualLength <= Int.MAX_VALUE) {
+                "Mapped region too large for a single mapping: $actualLength bytes"
+            }
+
+            val raf = RandomAccessFile(file, "r")
+            // Close the file handle if mapping fails so it does not leak.
+            val mapped = try {
+                raf.channel.map(FileChannel.MapMode.READ_ONLY, offset, actualLength)
+            } catch (t: Throwable) {
+                raf.close()
+                throw t
+            }
+            return JvmMappedMemoryChunk(
+                path = file.absolutePath,
+                fileOffset = offset,
+                size = actualLength,
+                buffer = mapped,
+                raf = raf
+            )
+        }
+
+        /**
+         * Map a region of a file into memory.
+ * + * @param path Path to the file + * @param offset Byte offset within the file + * @param length Number of bytes to map (0 = map to end of file) + */ + public fun open(path: String, offset: Long = 0, length: Long = 0): JvmMappedMemoryChunk = + open(File(path), offset, length) + } +} diff --git a/skainet-io/skainet-io-core/src/jvmMain/kotlin/sk/ainet/io/MappedRandomAccessSource.kt b/skainet-io/skainet-io-core/src/jvmMain/kotlin/sk/ainet/io/MappedRandomAccessSource.kt new file mode 100644 index 00000000..1e92c99b --- /dev/null +++ b/skainet-io/skainet-io-core/src/jvmMain/kotlin/sk/ainet/io/MappedRandomAccessSource.kt @@ -0,0 +1,52 @@ +package sk.ainet.io + +import java.io.File + +/** + * A [RandomAccessSource] backed by a memory-mapped file via [JvmMappedMemoryChunk]. + * + * Unlike [JvmRandomAccessSource] (which reads through a FileChannel into + * heap buffers), this variant lets the OS manage paging. Ideal for immutable + * model weights that are read repeatedly. + */ +public class MappedRandomAccessSource private constructor( + private val chunk: JvmMappedMemoryChunk +) : RandomAccessSource { + + override val size: Long get() = chunk.size + + override fun readAt(position: Long, length: Int): ByteArray = + chunk.readBytes(position, length) + + override fun readAt(position: Long, buffer: ByteArray, offset: Int, length: Int): Int { + require(position >= 0) { "Position must be non-negative: $position" } + require(offset >= 0) { "Offset must be non-negative: $offset" } + require(length >= 0) { "Length must be non-negative: $length" } + require(offset + length <= buffer.size) { + "Buffer overflow: offset=$offset, length=$length, buffer.size=${buffer.size}" + } + + val available = minOf(length.toLong(), size - position).toInt() + if (available <= 0) return 0 + + val bytes = chunk.readBytes(position, available) + bytes.copyInto(buffer, offset) + return available + } + + /** Return a [MemoryChunk] slice without copying — useful for loader integration. 
*/ + public fun sliceChunk(offset: Long, length: Long): MemoryChunk = + chunk.slice(offset, length) + + override fun close() { + chunk.close() + } + + public companion object { + public fun open(file: File): MappedRandomAccessSource = + MappedRandomAccessSource(JvmMappedMemoryChunk.open(file)) + + public fun open(path: String): MappedRandomAccessSource = + MappedRandomAccessSource(JvmMappedMemoryChunk.open(path)) + } +} diff --git a/skainet-io/skainet-io-core/src/jvmTest/kotlin/sk/ainet/io/JvmFileBackedResolverTest.kt b/skainet-io/skainet-io-core/src/jvmTest/kotlin/sk/ainet/io/JvmFileBackedResolverTest.kt new file mode 100644 index 00000000..64816cd3 --- /dev/null +++ b/skainet-io/skainet-io-core/src/jvmTest/kotlin/sk/ainet/io/JvmFileBackedResolverTest.kt @@ -0,0 +1,120 @@ +package sk.ainet.io + +import sk.ainet.lang.tensor.storage.BufferHandle +import sk.ainet.lang.tensor.storage.ByteArrayAccessor +import java.io.File +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class JvmFileBackedResolverTest { + + private fun withTempFile(content: ByteArray, block: (File) -> Unit) { + val file = File.createTempFile("resolver_test_", ".bin") + try { + file.writeBytes(content) + block(file) + } finally { + file.delete() + } + } + + @Test + fun resolveFileBackedHandleReadsFull() { + val data = ByteArray(256) { it.toByte() } + withTempFile(data) { file -> + val handle = BufferHandle.FileBacked( + path = file.absolutePath, + fileOffset = 0, + sizeInBytes = 256 + ) + val resolver = JvmFileBackedResolver.createResolver() + val accessor = resolver.resolve(handle) + + assertEquals(256L, accessor.sizeInBytes) + assertEquals(0.toByte(), accessor.readByte(0)) + assertEquals(255.toByte(), accessor.readByte(255)) + + val bytes = accessor.readBytes(10, 5) + assertEquals(5, bytes.size) + assertEquals(10.toByte(), bytes[0]) + assertEquals(14.toByte(), bytes[4]) + + accessor.close() + } + } + + @Test + fun resolveFileBackedHandleWithOffset() { + val 
data = ByteArray(1024) { it.toByte() } + withTempFile(data) { file -> + val handle = BufferHandle.FileBacked( + path = file.absolutePath, + fileOffset = 512, + sizeInBytes = 100 + ) + val resolver = JvmFileBackedResolver.createResolver() + val accessor = resolver.resolve(handle) + + assertEquals(100L, accessor.sizeInBytes) + // First byte of the mapped region should be byte 512 of the file + assertEquals(0.toByte(), accessor.readByte(0)) // 512 % 256 = 0 + accessor.close() + } + } + + @Test + fun resolveOwnedHandleDirectly() { + val data = byteArrayOf(10, 20, 30, 40) + val handle = BufferHandle.Owned(data) + val resolver = JvmFileBackedResolver.createResolver() + val accessor = resolver.resolve(handle) + + assertTrue(accessor is ByteArrayAccessor) + assertEquals(4L, accessor.sizeInBytes) + assertEquals(10.toByte(), accessor.readByte(0)) + accessor.close() + } + + @Test + fun resolveBorrowedHandleDirectly() { + val data = byteArrayOf(5, 6, 7) + val handle = BufferHandle.Borrowed(data) + val resolver = JvmFileBackedResolver.createResolver() + val accessor = resolver.resolve(handle) + + assertTrue(accessor is ByteArrayAccessor) + assertEquals(3L, accessor.sizeInBytes) + accessor.close() + } + + @Test + fun resolveAliasedHandle() { + val data = ByteArray(100) { it.toByte() } + val parent = BufferHandle.Owned(data) + val alias = BufferHandle.Aliased(parent, byteOffset = 10, sizeInBytes = 20) + val resolver = JvmFileBackedResolver.createResolver() + val accessor = resolver.resolve(alias) + + assertEquals(20L, accessor.sizeInBytes) + assertEquals(10.toByte(), accessor.readByte(0)) + assertEquals(29.toByte(), accessor.readByte(19)) + accessor.close() + } + + @Test + fun readAllBytesFromFileBacked() { + val data = byteArrayOf(1, 2, 3, 4, 5) + withTempFile(data) { file -> + val handle = BufferHandle.FileBacked(file.absolutePath, 0, 5) + val resolver = JvmFileBackedResolver.createResolver() + val accessor = resolver.resolve(handle) + + val all = accessor.readAllBytes() + 
assertEquals(5, all.size) + assertEquals(1.toByte(), all[0]) + assertEquals(5.toByte(), all[4]) + accessor.close() + } + } +} diff --git a/skainet-io/skainet-io-core/src/jvmTest/kotlin/sk/ainet/io/JvmMappedMemoryChunkTest.kt b/skainet-io/skainet-io-core/src/jvmTest/kotlin/sk/ainet/io/JvmMappedMemoryChunkTest.kt new file mode 100644 index 00000000..a0bfc2be --- /dev/null +++ b/skainet-io/skainet-io-core/src/jvmTest/kotlin/sk/ainet/io/JvmMappedMemoryChunkTest.kt @@ -0,0 +1,109 @@ +package sk.ainet.io + +import java.io.File +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class JvmMappedMemoryChunkTest { + + private fun withTempFile(content: ByteArray, block: (File) -> Unit) { + val file = File.createTempFile("mmap_test_", ".bin") + try { + file.writeBytes(content) + block(file) + } finally { + file.delete() + } + } + + @Test + fun mapEntireFile() { + val data = ByteArray(256) { it.toByte() } + withTempFile(data) { file -> + JvmMappedMemoryChunk.open(file).use { chunk -> + assertEquals(256L, chunk.size) + assertEquals(0.toByte(), chunk.readByte(0)) + assertEquals(255.toByte(), chunk.readByte(255)) + } + } + } + + @Test + fun mapRegion() { + val data = ByteArray(1024) { it.toByte() } + withTempFile(data) { file -> + JvmMappedMemoryChunk.open(file, offset = 100, length = 200).use { chunk -> + assertEquals(200L, chunk.size) + assertEquals(100.toByte(), chunk.readByte(0)) + assertEquals(101.toByte(), chunk.readByte(1)) + } + } + } + + @Test + fun readBytes() { + val data = ByteArray(64) { (it + 10).toByte() } + withTempFile(data) { file -> + JvmMappedMemoryChunk.open(file).use { chunk -> + val bytes = chunk.readBytes(0, 4) + assertEquals(4, bytes.size) + assertEquals(10.toByte(), bytes[0]) + assertEquals(13.toByte(), bytes[3]) + } + } + } + + @Test + fun sliceReturnsSubRegion() { + val data = ByteArray(128) { it.toByte() } + withTempFile(data) { file -> + JvmMappedMemoryChunk.open(file).use { chunk -> + val slice = 
chunk.slice(32, 16) + assertEquals(16L, slice.size) + assertEquals(32.toByte(), slice.readByte(0)) + assertEquals(47.toByte(), slice.readByte(15)) + } + } + } + + @Test + fun mappedRandomAccessSourceReads() { + val data = ByteArray(512) { it.toByte() } + withTempFile(data) { file -> + MappedRandomAccessSource.open(file).use { source -> + assertEquals(512L, source.size) + val bytes = source.readAt(100, 10) + assertEquals(10, bytes.size) + assertEquals(100.toByte(), bytes[0]) + } + } + } + + @Test + fun mappedRandomAccessSourceReadIntoBuffer() { + val data = ByteArray(256) { it.toByte() } + withTempFile(data) { file -> + MappedRandomAccessSource.open(file).use { source -> + val buffer = ByteArray(8) + val read = source.readAt(50, buffer, 0, 8) + assertEquals(8, read) + assertEquals(50.toByte(), buffer[0]) + assertEquals(57.toByte(), buffer[7]) + } + } + } + + @Test + fun mappedMemoryChunkProperties() { + val data = ByteArray(100) + withTempFile(data) { file -> + JvmMappedMemoryChunk.open(file, offset = 10, length = 80).use { chunk -> + assertTrue(chunk is MappedMemoryChunk) + assertEquals(file.absolutePath, chunk.path) + assertEquals(10L, chunk.fileOffset) + assertEquals(80L, chunk.size) + } + } + } +} diff --git a/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/Constants.kt b/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/Constants.kt index ce72a180..af302309 100644 --- a/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/Constants.kt +++ b/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/Constants.kt @@ -5,8 +5,6 @@ package sk.ainet.io.gguf * of github repo "https://github.com/ggerganov/llama.cpp" */ -//TODO convert the rest of file from constants.py - const val GGUF_MAGIC = 0x46554747u const val GGUF_VERSION = 3 const val GGUF_DEFAULT_ALIGNMENT = 32 diff --git a/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/GgufParametersLoader.kt 
b/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/GgufParametersLoader.kt index 4b2d12ce..df7f3ba0 100644 --- a/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/GgufParametersLoader.kt +++ b/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/GgufParametersLoader.kt @@ -11,13 +11,20 @@ import sk.ainet.lang.types.Int32 import kotlin.reflect.KClass /** - * ParametersLoader implementation backed by GGUFReader. + * ParametersLoader implementation backed by the legacy [GGUFReader]. * * Notes: - * - Currently supports loading tensors as FP32 or Int32. Other dtypes can be added as needed. + * - Currently supports loading tensors as FP32 or Int32 only. * - For quantized GGML tensor payloads, this implementation does not perform dequantization and will throw. * - A lightweight progress callback can be provided to observe per-tensor progress (current/total/name). + * + * @see StreamingGgufParametersLoader for the recommended streaming-based loader + * that supports quantized types and memory-efficient parsing. */ +@Deprecated( + message = "Use StreamingGgufParametersLoader for memory-efficient loading with quantized type support", + replaceWith = ReplaceWith("StreamingGgufParametersLoader(sourceProvider, onProgress)") +) class GgufParametersLoader( private val sourceProvider: () -> Source, private val onProgress: (current: Long, total: Long, message: String?) 
-> Unit = { _, _, _ -> } diff --git a/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/Quants.kt b/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/Quants.kt index c2d14f21..ddaeb2cd 100644 --- a/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/Quants.kt +++ b/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/Quants.kt @@ -1,14 +1,32 @@ package sk.ainet.io.gguf /** - * This is a kotlin gguf reader related logic interpreted from python code "gguf-py/gguf/quants.py" - * of github repo "https://github.com/ggerganov/llama.cpp" + * Quantization shape and size utilities for GGUF tensor loading. + * + * Ported from "gguf-py/gguf/quants.py" in llama.cpp. + * These functions handle the mapping between logical element shapes + * and physical byte shapes for quantized tensor formats. + * + * @see [GGML_QUANT_SIZES] for block-size and type-size definitions + * @see [DequantOps][sk.ainet.io.gguf.dequant.DequantOps] for actual dequantization kernels */ -//TODO convert the rest of file from quants.py - +/** + * Convert a logical element shape to a physical byte shape for quantized storage. + * + * The last dimension (row size) must be a multiple of the quantization block size. + * It is replaced by `(row / blockSize) * typeSize` to reflect the packed byte layout. + * + * Example: Q4_K with shape [32, 256] → [32, 144] (256/256 * 144) + * + * @param shape Logical element dimensions + * @param quantType The quantization format + * @return Physical byte dimensions + * @throws IllegalArgumentException if the last dimension is not block-aligned + */ fun quantShapeToByteShape(shape: List, quantType: GGMLQuantizationType): List { - val (blockSize, typeSize) = GGML_QUANT_SIZES[quantType]!! 
+    val (blockSize, typeSize) = GGML_QUANT_SIZES[quantType]
+        ?: throw IllegalArgumentException("Unknown quantization type: ${quantType.name}")
     if (shape.last().toInt() % blockSize != 0) {
         throw IllegalArgumentException(
             "Quantized tensor row size (${shape.last()}) is not a multiple of ${quantType.name} block size ($blockSize)"
@@ -18,3 +36,111 @@ fun quantShapeToByteShape(shape: List<ULong>, quantType: GGMLQuantizationType): List<ULong> {
     val newShape = shape.dropLast(1) + (shape.last() / blockSize.toULong() * typeSize.toULong())
     return newShape
 }
+
+/**
+ * Convert a physical byte shape back to a logical element shape.
+ *
+ * Inverse of [quantShapeToByteShape]. The last dimension (byte row size)
+ * must be a multiple of the type size. It is replaced by
+ * `(byteRow / typeSize) * blockSize`.
+ *
+ * Example: Q4_K with byte shape [32, 144] → [32, 256]
+ *
+ * @param byteShape Physical byte dimensions
+ * @param quantType The quantization format
+ * @return Logical element dimensions
+ * @throws IllegalArgumentException if the last dimension is not aligned to type size
+ */
+fun byteShapeToQuantShape(byteShape: List<ULong>, quantType: GGMLQuantizationType): List<ULong> {
+    val (blockSize, typeSize) = GGML_QUANT_SIZES[quantType]
+        ?: throw IllegalArgumentException("Unknown quantization type: ${quantType.name}")
+    if (byteShape.last() % typeSize.toULong() != 0UL) { // ULong modulo: toInt() would truncate rows > 2^31 bytes
+        throw IllegalArgumentException(
+            "Byte row size (${byteShape.last()}) is not a multiple of ${quantType.name} type size ($typeSize)"
+        )
+    }
+
+    val newShape = byteShape.dropLast(1) + (byteShape.last() / typeSize.toULong() * blockSize.toULong())
+    return newShape
+}
+
+/**
+ * Compute the total number of logical elements from a shape.
+ *
+ * @param shape Logical element dimensions
+ * @return Product of all dimensions, or 1 for a scalar (empty shape)
+ */
+fun quantElementCount(shape: List<ULong>): ULong {
+    if (shape.isEmpty()) return 1u
+    return shape.fold(1UL) { acc, dim -> acc * dim }
+}
+
+/**
+ * Compute the total byte size for a quantized tensor.
+ *
+ * @param elementCount Total number of logical elements
+ * @param quantType The quantization format
+ * @return Number of bytes required to store the tensor
+ * @throws IllegalArgumentException if the element count is not block-aligned
+ */
+fun quantByteSize(elementCount: ULong, quantType: GGMLQuantizationType): ULong {
+    val (blockSize, typeSize) = GGML_QUANT_SIZES[quantType]
+        ?: throw IllegalArgumentException("Unknown quantization type: ${quantType.name}")
+    if (elementCount % blockSize.toULong() != 0UL) { // ULong modulo: toInt() would truncate counts > 2^31 elements
+        throw IllegalArgumentException(
+            "Element count ($elementCount) is not a multiple of ${quantType.name} block size ($blockSize)"
+        )
+    }
+    return elementCount / blockSize.toULong() * typeSize.toULong()
+}
+
+/**
+ * Check whether a quantization type uses block quantization (vs element-wise).
+ *
+ * Block-quantized types pack multiple elements per block with shared
+ * scale/offset metadata. Element-wise types (F32, F16, I8, etc.) have
+ * a block size of 1.
+ *
+ * @param quantType The quantization format
+ * @return true if block size > 1
+ */
+fun isBlockQuantized(quantType: GGMLQuantizationType): Boolean {
+    val (blockSize, _) = GGML_QUANT_SIZES[quantType] ?: return false
+    return blockSize > 1
+}
+
+/**
+ * Get the block size for a quantization type.
+ *
+ * @param quantType The quantization format
+ * @return Number of elements per block, or null if unknown
+ */
+fun quantBlockSize(quantType: GGMLQuantizationType): Int? {
+    return GGML_QUANT_SIZES[quantType]?.first
+}
+
+/**
+ * Get the byte size per block for a quantization type.
+ * + * @param quantType The quantization format + * @return Number of bytes per block, or null if unknown + */ +fun quantTypeSize(quantType: GGMLQuantizationType): Int? { + return GGML_QUANT_SIZES[quantType]?.second +} + +/** + * Validate that a byte array has the correct size for a given quantized tensor. + * + * @param bytes Raw byte data + * @param elementCount Number of logical elements + * @param quantType The quantization format + * @throws IllegalArgumentException if the size doesn't match + */ +fun validateQuantizedBytes(bytes: ByteArray, elementCount: ULong, quantType: GGMLQuantizationType) { + val expectedBytes = quantByteSize(elementCount, quantType) + require(bytes.size.toULong() == expectedBytes) { + "Byte array size (${bytes.size}) does not match expected size ($expectedBytes) " + + "for $elementCount elements of type ${quantType.name}" + } +} diff --git a/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/StreamingGGUFReader.kt b/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/StreamingGGUFReader.kt index f892c4d5..cc07a2cf 100644 --- a/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/StreamingGGUFReader.kt +++ b/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/StreamingGGUFReader.kt @@ -1,6 +1,8 @@ package sk.ainet.io.gguf import sk.ainet.io.RandomAccessSource +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.storage.* /** * Streaming GGUF reader that parses metadata without loading the entire file. @@ -97,6 +99,93 @@ public class StreamingGGUFReader private constructor( return source.readAt(tensor.absoluteDataOffset, buffer, offset, tensor.nBytes) } + // ========== TensorStorage Loading ========== + + /** + * Load a tensor as a [TensorStorage] descriptor with borrowed bytes. + * The returned storage borrows the loaded byte array (no extra copy). 
+ */ + public fun loadTensorStorage(tensor: StreamingTensorInfo): TensorStorage { + val bytes = loadTensorData(tensor) + val shape = Shape(*tensor.shape.map { it.toInt() }.toIntArray()) + return TensorStorage( + shape = shape, + logicalType = ggmlTypeToLogical(tensor.tensorType), + encoding = ggmlTypeToEncoding(tensor.tensorType), + buffer = BufferHandle.Borrowed(bytes, isMutable = false), + placement = Placement.CPU_HEAP + ) + } + + /** + * Load a tensor by name as a [TensorStorage] descriptor. + */ + public fun loadTensorStorage(name: String): TensorStorage { + val tensor = _tensors.firstOrNull { it.name == name } + ?: throw IllegalArgumentException("Tensor not found: $name") + return loadTensorStorage(tensor) + } + + /** + * Create a file-backed [TensorStorage] that references the tensor's bytes + * in the original file without loading them into heap. + * + * Requires the source to be file-based. The returned storage uses + * [BufferHandle.FileBacked] with the tensor's absolute file offset. 
+ * + * @param tensor The tensor info from [tensors] list + * @param filePath Path to the GGUF file (needed for the FileBacked handle) + */ + public fun loadTensorStorageMapped(tensor: StreamingTensorInfo, filePath: String): TensorStorage { + val shape = Shape(*tensor.shape.map { it.toInt() }.toIntArray()) + return TensorStorage( + shape = shape, + logicalType = ggmlTypeToLogical(tensor.tensorType), + encoding = ggmlTypeToEncoding(tensor.tensorType), + buffer = BufferHandle.FileBacked( + path = filePath, + fileOffset = tensor.absoluteDataOffset, + sizeInBytes = tensor.nBytes.toLong() + ), + placement = Placement.MMAP_WEIGHTS + ) + } + + private fun ggmlTypeToLogical(type: GGMLQuantizationType): LogicalDType = when (type) { + GGMLQuantizationType.F32 -> LogicalDType.FLOAT32 + GGMLQuantizationType.F16 -> LogicalDType.FLOAT16 + GGMLQuantizationType.BF16 -> LogicalDType.BFLOAT16 + GGMLQuantizationType.F64 -> LogicalDType.FLOAT64 + GGMLQuantizationType.I8 -> LogicalDType.INT8 + GGMLQuantizationType.I16 -> LogicalDType.INT16 + GGMLQuantizationType.I32 -> LogicalDType.INT32 + GGMLQuantizationType.I64 -> LogicalDType.INT64 + // Quantized types logically represent floats + else -> LogicalDType.FLOAT32 + } + + private fun ggmlTypeToEncoding(type: GGMLQuantizationType): TensorEncoding = when (type) { + GGMLQuantizationType.F32 -> TensorEncoding.Dense(4) + GGMLQuantizationType.F16 -> TensorEncoding.Dense(2) + GGMLQuantizationType.BF16 -> TensorEncoding.Dense(2) + GGMLQuantizationType.F64 -> TensorEncoding.Dense(8) + GGMLQuantizationType.I8 -> TensorEncoding.Dense(1) + GGMLQuantizationType.I16 -> TensorEncoding.Dense(2) + GGMLQuantizationType.I32 -> TensorEncoding.Dense(4) + GGMLQuantizationType.I64 -> TensorEncoding.Dense(8) + GGMLQuantizationType.Q4_K -> TensorEncoding.Q4_K + GGMLQuantizationType.Q8_0 -> TensorEncoding.Q8_0 + else -> { + // For other quantized types, use Opaque with raw byte count + val quantInfo = GGML_QUANT_SIZES[type] + if (quantInfo != null) { + 
TensorEncoding.Opaque(type.name, 0) // size computed from tensor info + } else { + TensorEncoding.Opaque(type.name, 0) + } + } + } + // ========== Parsing Implementation ========== private fun parse() { diff --git a/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/StreamingGgufParametersLoader.kt b/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/StreamingGgufParametersLoader.kt new file mode 100644 index 00000000..00be76c2 --- /dev/null +++ b/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/StreamingGgufParametersLoader.kt @@ -0,0 +1,173 @@ +package sk.ainet.io.gguf + +import sk.ainet.context.ExecutionContext +import sk.ainet.io.ParametersLoader +import sk.ainet.io.RandomAccessSource +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.data.Q4_KBlockTensorData +import sk.ainet.lang.tensor.data.Q8_0BlockTensorData +import sk.ainet.lang.types.DType +import sk.ainet.lang.types.FP32 +import sk.ainet.lang.types.Int32 +import kotlin.reflect.KClass + +/** + * Streaming GGUF parameters loader — the recommended path for loading GGUF models. + * + * Unlike [GgufParametersLoader] (which uses the legacy [GGUFReader] and rejects + * quantized types), this loader: + * - Uses [StreamingGGUFReader] for memory-efficient parsing + * - Supports quantized types (Q4_K, Q8_0) as packed [TensorData] + * - Loads tensor data on-demand without heap-loading the full file + * - Preserves quantized layout through the loading pipeline + * + * For F32 and I32 tensors, data is returned as standard dense arrays. + * For quantized tensors, data is returned as packed block storage + * (e.g., [Q4_KBlockTensorData], [Q8_0BlockTensorData]). + */ +public class StreamingGgufParametersLoader( + private val sourceProvider: () -> RandomAccessSource, + private val onProgress: (current: Long, total: Long, message: String?) 
-> Unit = { _, _, _ -> }
+) : ParametersLoader {
+
+    @Suppress("UNCHECKED_CAST")
+    override suspend fun <T : DType> load(
+        ctx: ExecutionContext,
+        dtype: KClass<T>,
+        onTensorLoaded: (String, Tensor<T>) -> Unit
+    ) {
+        StreamingGGUFReader.open(sourceProvider()).use { reader ->
+            val tensors = reader.tensors
+            val total = tensors.size.toLong()
+            var current = 0L
+
+            for (tensorInfo in tensors) {
+                val shape = Shape(*tensorInfo.shape.map { it.toInt() }.toIntArray())
+                val rawBytes = reader.loadTensorData(tensorInfo)
+
+                val tensor: Tensor<T>? = when (tensorInfo.tensorType) {
+                    GGMLQuantizationType.F32 -> {
+                        val floats = bytesToFloatArray(rawBytes)
+                        when (dtype) {
+                            FP32::class -> ctx.fromFloatArray(shape, dtype, floats) as Tensor<T>
+                            else -> null
+                        }
+                    }
+
+                    GGMLQuantizationType.I32 -> {
+                        val ints = bytesToIntArray(rawBytes)
+                        when (dtype) {
+                            Int32::class -> ctx.fromIntArray(shape, dtype, ints) as Tensor<T>
+                            else -> null
+                        }
+                    }
+
+                    GGMLQuantizationType.F16 -> {
+                        val floats = dequantF16(rawBytes)
+                        when (dtype) {
+                            FP32::class -> ctx.fromFloatArray(shape, dtype, floats) as Tensor<T>
+                            else -> null
+                        }
+                    }
+
+                    GGMLQuantizationType.BF16 -> {
+                        val floats = dequantBF16(rawBytes)
+                        when (dtype) {
+                            FP32::class -> ctx.fromFloatArray(shape, dtype, floats) as Tensor<T>
+                            else -> null
+                        }
+                    }
+
+                    GGMLQuantizationType.Q4_K -> {
+                        @Suppress("UNCHECKED_CAST")
+                        val packed = Q4_KBlockTensorData.fromRawBytes(shape, rawBytes)
+                        ctx.fromData(packed as sk.ainet.lang.tensor.data.TensorData<T>, dtype)
+                    }
+
+                    GGMLQuantizationType.Q8_0 -> {
+                        @Suppress("UNCHECKED_CAST")
+                        val packed = Q8_0BlockTensorData.fromRawBytes(shape, rawBytes)
+                        ctx.fromData(packed as sk.ainet.lang.tensor.data.TensorData<T>, dtype)
+                    }
+
+                    else -> {
+                        onProgress(current, total, "SKIP: ${tensorInfo.name} (unsupported type ${tensorInfo.tensorType})")
+                        null
+                    }
+                }
+
+                if (tensor != null) {
+                    onTensorLoaded(tensorInfo.name, tensor)
+                }
+
+                current += 1
+                onProgress(current, total, tensorInfo.name)
+            }
+        }
+    }
+
+    private fun
bytesToFloatArray(bytes: ByteArray): FloatArray { + val count = bytes.size / 4 + return FloatArray(count) { i -> + val off = i * 4 + Float.fromBits( + (bytes[off].toInt() and 0xFF) or + ((bytes[off + 1].toInt() and 0xFF) shl 8) or + ((bytes[off + 2].toInt() and 0xFF) shl 16) or + ((bytes[off + 3].toInt() and 0xFF) shl 24) + ) + } + } + + private fun bytesToIntArray(bytes: ByteArray): IntArray { + val count = bytes.size / 4 + return IntArray(count) { i -> + val off = i * 4 + (bytes[off].toInt() and 0xFF) or + ((bytes[off + 1].toInt() and 0xFF) shl 8) or + ((bytes[off + 2].toInt() and 0xFF) shl 16) or + ((bytes[off + 3].toInt() and 0xFF) shl 24) + } + } + + private fun dequantF16(bytes: ByteArray): FloatArray { + val count = bytes.size / 2 + return FloatArray(count) { i -> + val off = i * 2 + val halfBits = (bytes[off].toInt() and 0xFF) or + ((bytes[off + 1].toInt() and 0xFF) shl 8) + halfToFloat(halfBits) + } + } + + private fun dequantBF16(bytes: ByteArray): FloatArray { + val count = bytes.size / 2 + return FloatArray(count) { i -> + val off = i * 2 + val bf16Bits = (bytes[off].toInt() and 0xFF) or + ((bytes[off + 1].toInt() and 0xFF) shl 8) + Float.fromBits(bf16Bits shl 16) + } + } + + private fun halfToFloat(hbits: Int): Float { + val sign = (hbits and 0x8000) shl 16 + val exp = (hbits and 0x7C00) shr 10 + val mant = hbits and 0x03FF + + return when (exp) { + 0 -> { + if (mant == 0) Float.fromBits(sign) + else { + var m = mant; var e = -14 + while ((m and 0x400) == 0) { m = m shl 1; e-- } + m = m and 0x3FF + Float.fromBits(sign or ((e + 127) shl 23) or (m shl 13)) + } + } + 31 -> Float.fromBits(sign or (0xFF shl 23) or (mant shl 13)) + else -> Float.fromBits(sign or ((exp - 15 + 127) shl 23) or (mant shl 13)) + } + } +} diff --git a/skainet-io/skainet-io-gguf/src/jvmTest/kotlin/sk/ainet/io/gguf/QuantsTest.kt b/skainet-io/skainet-io-gguf/src/jvmTest/kotlin/sk/ainet/io/gguf/QuantsTest.kt new file mode 100644 index 00000000..fbb54e74 --- /dev/null +++ 
b/skainet-io/skainet-io-gguf/src/jvmTest/kotlin/sk/ainet/io/gguf/QuantsTest.kt
@@ -0,0 +1,180 @@
+package sk.ainet.io.gguf
+
+import kotlin.test.Test
+import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
+import kotlin.test.assertFalse
+import kotlin.test.assertNotNull
+import kotlin.test.assertTrue
+
+/**
+ * Tests for quantization shape and size utilities in [Quants.kt].
+ */
+class QuantsTest {
+
+    // --- quantShapeToByteShape ---
+
+    @Test
+    fun quantShapeToByteShape_Q4_K() {
+        val shape = listOf(32UL, 256UL)
+        val result = quantShapeToByteShape(shape, GGMLQuantizationType.Q4_K)
+        // Q4_K: blockSize=256, typeSize=144 → 256/256 * 144 = 144
+        assertEquals(listOf(32UL, 144UL), result)
+    }
+
+    @Test
+    fun quantShapeToByteShape_Q8_0() {
+        val shape = listOf(128UL)
+        val result = quantShapeToByteShape(shape, GGMLQuantizationType.Q8_0)
+        // Q8_0: blockSize=32, typeSize=34 → 128/32 * 34 = 136
+        assertEquals(listOf(136UL), result)
+    }
+
+    @Test
+    fun quantShapeToByteShape_F32_passthrough() {
+        val shape = listOf(10UL, 20UL)
+        val result = quantShapeToByteShape(shape, GGMLQuantizationType.F32)
+        // F32: blockSize=1, typeSize=4 → 20/1 * 4 = 80
+        assertEquals(listOf(10UL, 80UL), result)
+    }
+
+    @Test
+    fun quantShapeToByteShape_unaligned_throws() {
+        assertFailsWith<IllegalArgumentException> {
+            quantShapeToByteShape(listOf(100UL), GGMLQuantizationType.Q4_K) // 100 not multiple of 256
+        }
+    }
+
+    // --- byteShapeToQuantShape ---
+
+    @Test
+    fun byteShapeToQuantShape_Q4_K() {
+        val byteShape = listOf(32UL, 144UL)
+        val result = byteShapeToQuantShape(byteShape, GGMLQuantizationType.Q4_K)
+        assertEquals(listOf(32UL, 256UL), result)
+    }
+
+    @Test
+    fun byteShapeToQuantShape_Q8_0() {
+        val byteShape = listOf(136UL)
+        val result = byteShapeToQuantShape(byteShape, GGMLQuantizationType.Q8_0)
+        assertEquals(listOf(128UL), result)
+    }
+
+    @Test
+    fun byteShapeToQuantShape_roundTrip() {
+        val original = listOf(16UL, 512UL)
+        val byteShape = quantShapeToByteShape(original,
GGMLQuantizationType.Q4_K)
+        val recovered = byteShapeToQuantShape(byteShape, GGMLQuantizationType.Q4_K)
+        assertEquals(original, recovered)
+    }
+
+    @Test
+    fun byteShapeToQuantShape_unaligned_throws() {
+        assertFailsWith<IllegalArgumentException> {
+            byteShapeToQuantShape(listOf(100UL), GGMLQuantizationType.Q8_0) // 100 not multiple of 34
+        }
+    }
+
+    // --- quantElementCount ---
+
+    @Test
+    fun quantElementCount_standard() {
+        assertEquals(1024UL, quantElementCount(listOf(32UL, 32UL)))
+    }
+
+    @Test
+    fun quantElementCount_scalar() {
+        assertEquals(1UL, quantElementCount(emptyList()))
+    }
+
+    @Test
+    fun quantElementCount_1d() {
+        assertEquals(256UL, quantElementCount(listOf(256UL)))
+    }
+
+    // --- quantByteSize ---
+
+    @Test
+    fun quantByteSize_Q4_K() {
+        // 256 elements → 1 block → 144 bytes
+        assertEquals(144UL, quantByteSize(256UL, GGMLQuantizationType.Q4_K))
+    }
+
+    @Test
+    fun quantByteSize_Q8_0() {
+        // 64 elements → 2 blocks → 68 bytes
+        assertEquals(68UL, quantByteSize(64UL, GGMLQuantizationType.Q8_0))
+    }
+
+    @Test
+    fun quantByteSize_F32() {
+        assertEquals(40UL, quantByteSize(10UL, GGMLQuantizationType.F32))
+    }
+
+    // --- isBlockQuantized ---
+
+    @Test
+    fun isBlockQuantized_true() {
+        assertTrue(isBlockQuantized(GGMLQuantizationType.Q4_K))
+        assertTrue(isBlockQuantized(GGMLQuantizationType.Q8_0))
+        assertTrue(isBlockQuantized(GGMLQuantizationType.Q2_K))
+        assertTrue(isBlockQuantized(GGMLQuantizationType.TQ2_0))
+    }
+
+    @Test
+    fun isBlockQuantized_false() {
+        assertFalse(isBlockQuantized(GGMLQuantizationType.F32))
+        assertFalse(isBlockQuantized(GGMLQuantizationType.F16))
+        assertFalse(isBlockQuantized(GGMLQuantizationType.I8))
+    }
+
+    // --- quantBlockSize / quantTypeSize ---
+
+    @Test
+    fun quantBlockSize_known() {
+        assertEquals(256, quantBlockSize(GGMLQuantizationType.Q4_K))
+        assertEquals(32, quantBlockSize(GGMLQuantizationType.Q8_0))
+        assertEquals(1, quantBlockSize(GGMLQuantizationType.F32))
+    }
+
+    @Test
+    fun quantTypeSize_known() {
+        assertEquals(144,
quantTypeSize(GGMLQuantizationType.Q4_K))
+        assertEquals(34, quantTypeSize(GGMLQuantizationType.Q8_0))
+        assertEquals(4, quantTypeSize(GGMLQuantizationType.F32))
+    }
+
+    @Test
+    fun quantBlockSize_unknown_returns_null() {
+        assertEquals(null, quantBlockSize(GGMLQuantizationType.UNKNOWN))
+    }
+
+    // --- validateQuantizedBytes ---
+
+    @Test
+    fun validateQuantizedBytes_correct_size() {
+        val bytes = ByteArray(144) // 1 Q4_K block
+        validateQuantizedBytes(bytes, 256UL, GGMLQuantizationType.Q4_K)
+    }
+
+    @Test
+    fun validateQuantizedBytes_wrong_size_throws() {
+        assertFailsWith<IllegalArgumentException> {
+            validateQuantizedBytes(ByteArray(100), 256UL, GGMLQuantizationType.Q4_K)
+        }
+    }
+
+    // --- Coverage for all quant types in GGML_QUANT_SIZES ---
+
+    @Test
+    fun allQuantSizesHaveBlockAndTypeSize() {
+        for ((type, sizes) in GGML_QUANT_SIZES) {
+            val (blockSize, typeSize) = sizes
+            assertTrue(blockSize > 0, "Block size for $type must be positive")
+            assertTrue(typeSize > 0, "Type size for $type must be positive")
+            assertNotNull(quantBlockSize(type))
+            assertNotNull(quantTypeSize(type))
+        }
+    }
+}
diff --git a/skainet-io/skainet-io-gguf/src/jvmTest/kotlin/sk/ainet/io/gguf/StorageIntegrationTest.kt b/skainet-io/skainet-io-gguf/src/jvmTest/kotlin/sk/ainet/io/gguf/StorageIntegrationTest.kt
new file mode 100644
index 00000000..29cc46c2
--- /dev/null
+++ b/skainet-io/skainet-io-gguf/src/jvmTest/kotlin/sk/ainet/io/gguf/StorageIntegrationTest.kt
@@ -0,0 +1,225 @@
+package sk.ainet.io.gguf
+
+import org.junit.Test
+import sk.ainet.io.JvmFileBackedResolver
+import sk.ainet.io.JvmRandomAccessSource
+import sk.ainet.lang.tensor.storage.*
+import java.io.File
+import java.io.RandomAccessFile
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import kotlin.test.assertEquals
+import kotlin.test.assertFalse
+import kotlin.test.assertTrue
+
+/**
+ * Integration tests that exercise the full storage pipeline:
+ * GGUF file → StreamingGGUFReader → TensorStorage → BufferAccessor
+ *
+ * Uses a synthetically
constructed minimal GGUF file with: + * - 1 F32 tensor (4 elements, 16 bytes) + * - 1 Q8_0 tensor (32 elements, 34 bytes) + */ +class StorageIntegrationTest { + + private fun createTestGgufFile(): File { + val file = File.createTempFile("storage_test_", ".gguf") + RandomAccessFile(file, "rw").use { raf -> + val buf = ByteBuffer.allocate(4096).order(ByteOrder.LITTLE_ENDIAN) + + // === Header === + // Magic: GGUF + buf.putInt(0x46554747.toInt()) + // Version: 3 + buf.putInt(3) + // Tensor count: 2 + buf.putLong(2) + // KV count: 1 + buf.putLong(1) + + // === KV Section === + // Key: "general.architecture" = "test" + val key = "general.architecture".encodeToByteArray() + buf.putLong(key.size.toLong()) // key length + buf.put(key) + buf.putInt(GGUFValueType.STRING.value) // value type + val value = "test".encodeToByteArray() + buf.putLong(value.size.toLong()) // string length + buf.put(value) + + // === Tensor Info Section === + // Tensor 1: "weight_f32", F32, shape [4], 16 bytes + val name1 = "weight_f32".encodeToByteArray() + buf.putLong(name1.size.toLong()) + buf.put(name1) + buf.putInt(1) // n_dims + buf.putLong(4) // dim[0] + buf.putInt(GGMLQuantizationType.F32.value) // type + buf.putLong(0) // relative offset = 0 + + // Tensor 2: "weight_q80", Q8_0, shape [32], 34 bytes + val name2 = "weight_q80".encodeToByteArray() + buf.putLong(name2.size.toLong()) + buf.put(name2) + buf.putInt(1) // n_dims + buf.putLong(32) // dim[0] + buf.putInt(GGMLQuantizationType.Q8_0.value) // type + buf.putLong(16) // relative offset = 16 (after the F32 tensor) + + // === Alignment padding === + val currentPos = buf.position() + val alignment = 32 + val padding = (alignment - (currentPos % alignment)) % alignment + for (i in 0 until padding) buf.put(0) + + // === Data Section === + // F32 tensor data: [1.0, 2.0, 3.0, 4.0] + buf.putFloat(1.0f) + buf.putFloat(2.0f) + buf.putFloat(3.0f) + buf.putFloat(4.0f) + + // Q8_0 tensor data: 1 block = 2 bytes scale + 32 bytes codes + // Scale = 1.0 
in f16 = 0x3C00 little-endian + buf.put(0x00.toByte()) + buf.put(0x3C.toByte()) + // Codes: 1, 2, 3, ... 32 + for (i in 1..32) buf.put(i.toByte()) + + // Write to file + buf.flip() + val bytes = ByteArray(buf.remaining()) + buf.get(bytes) + raf.write(bytes) + } + return file + } + + @Test + fun `streaming reader loads TensorStorage with correct metadata`() { + val file = createTestGgufFile() + try { + JvmRandomAccessSource.open(file).use { source -> + val reader = StreamingGGUFReader.open(source) + assertEquals(2, reader.tensors.size.toInt()) + + // F32 tensor + val f32Storage = reader.loadTensorStorage("weight_f32") + assertEquals(LogicalDType.FLOAT32, f32Storage.logicalType) + assertEquals(TensorEncoding.Dense(4), f32Storage.encoding) + assertEquals(Ownership.BORROWED, f32Storage.ownership) + assertEquals(16L, f32Storage.physicalBytes) + assertEquals(4L, f32Storage.elementCount) + assertFalse(f32Storage.isFileBacked) + + // Q8_0 tensor + val q80Storage = reader.loadTensorStorage("weight_q80") + assertEquals(LogicalDType.FLOAT32, q80Storage.logicalType) + assertEquals(TensorEncoding.Q8_0, q80Storage.encoding) + assertEquals(Ownership.BORROWED, q80Storage.ownership) + assertEquals(34L, q80Storage.physicalBytes) + assertEquals(32L, q80Storage.elementCount) + } + } finally { + file.delete() + } + } + + @Test + fun `file-backed storage resolves through mmap`() { + val file = createTestGgufFile() + try { + JvmRandomAccessSource.open(file).use { source -> + val reader = StreamingGGUFReader.open(source) + + // Get file-backed storage + val storage = reader.loadTensorStorageMapped( + reader.tensors.first { it.name == "weight_f32" }, + file.absolutePath + ) + + assertTrue(storage.isFileBacked) + assertEquals(Ownership.FILE_BACKED, storage.ownership) + assertEquals(Placement.MMAP_WEIGHTS, storage.placement) + assertFalse(storage.isMutable) + + // Resolve through mmap and read actual bytes + val resolver = JvmFileBackedResolver.createResolver() + val accessor = 
resolver.resolve(storage.buffer) + assertEquals(16L, accessor.sizeInBytes) + + // Read F32 values: should be 1.0, 2.0, 3.0, 4.0 + val bytes = accessor.readAllBytes() + val bb = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN) + assertEquals(1.0f, bb.getFloat(0)) + assertEquals(2.0f, bb.getFloat(4)) + assertEquals(3.0f, bb.getFloat(8)) + assertEquals(4.0f, bb.getFloat(12)) + + accessor.close() + } + } finally { + file.delete() + } + } + + @Test + fun `Q8_0 file-backed storage reads packed block data correctly`() { + val file = createTestGgufFile() + try { + JvmRandomAccessSource.open(file).use { source -> + val reader = StreamingGGUFReader.open(source) + + val storage = reader.loadTensorStorageMapped( + reader.tensors.first { it.name == "weight_q80" }, + file.absolutePath + ) + + assertTrue(storage.isFileBacked) + assertEquals(TensorEncoding.Q8_0, storage.encoding) + + val resolver = JvmFileBackedResolver.createResolver() + val accessor = resolver.resolve(storage.buffer) + assertEquals(34L, accessor.sizeInBytes) + + // First 2 bytes: f16 scale (1.0 = 0x3C00) + assertEquals(0x00.toByte(), accessor.readByte(0)) + assertEquals(0x3C.toByte(), accessor.readByte(1)) + // Code bytes: 1, 2, 3... 
+ assertEquals(1.toByte(), accessor.readByte(2)) + assertEquals(32.toByte(), accessor.readByte(33)) + + accessor.close() + } + } finally { + file.delete() + } + } + + @Test + fun `memory report shows correct metrics for mixed model`() { + val file = createTestGgufFile() + try { + JvmRandomAccessSource.open(file).use { source -> + val reader = StreamingGGUFReader.open(source) + val tracker = MemoryTracker() + + for (tensor in reader.tensors) { + val storage = reader.loadTensorStorage(tensor) + tracker.record(tensor.name, storage) + } + + val report = tracker.report() + assertEquals(2, report.tensorCount) + assertEquals(2, report.borrowedCount) + assertEquals(0, report.ownedCount) + // F32: 4*4=16 logical, 16 physical + // Q8_0: 32*4=128 logical, 34 physical + assertEquals(16L + 128L, report.totalLogicalBytes) + assertEquals(16L + 34L, report.totalPhysicalBytes) + } + } finally { + file.delete() + } + } +} diff --git a/skainet-io/skainet-io-gguf/src/jvmTest/kotlin/sk/ainet/io/gguf/StreamingGgufParametersLoaderTest.kt b/skainet-io/skainet-io-gguf/src/jvmTest/kotlin/sk/ainet/io/gguf/StreamingGgufParametersLoaderTest.kt new file mode 100644 index 00000000..657fc76a --- /dev/null +++ b/skainet-io/skainet-io-gguf/src/jvmTest/kotlin/sk/ainet/io/gguf/StreamingGgufParametersLoaderTest.kt @@ -0,0 +1,161 @@ +package sk.ainet.io.gguf + +import org.junit.Test +import sk.ainet.context.DefaultDataExecutionContext +import sk.ainet.io.JvmRandomAccessSource +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.data.FloatArrayTensorData +import sk.ainet.lang.tensor.data.Q8_0TensorData +import sk.ainet.lang.tensor.storage.PackedBlockStorage +import sk.ainet.lang.types.FP32 +import java.io.File +import java.io.RandomAccessFile +import java.nio.ByteBuffer +import java.nio.ByteOrder +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class StreamingGgufParametersLoaderTest { + + /** + * Build a minimal GGUF file with F32 and 
Q8_0 tensors. + * Reuses the approach from StorageIntegrationTest. + */ + private fun createTestGgufFile(): File { + val file = File.createTempFile("loader_test_", ".gguf") + RandomAccessFile(file, "rw").use { raf -> + val buf = ByteBuffer.allocate(4096).order(ByteOrder.LITTLE_ENDIAN) + + buf.putInt(0x46554747.toInt()) // Magic + buf.putInt(3) // Version + buf.putLong(2) // Tensor count + buf.putLong(1) // KV count + + // KV: "general.architecture" = "test" + val key = "general.architecture".encodeToByteArray() + buf.putLong(key.size.toLong()) + buf.put(key) + buf.putInt(GGUFValueType.STRING.value) + val value = "test".encodeToByteArray() + buf.putLong(value.size.toLong()) + buf.put(value) + + // Tensor 1: "weight_f32", F32, shape [4] + val name1 = "weight_f32".encodeToByteArray() + buf.putLong(name1.size.toLong()) + buf.put(name1) + buf.putInt(1) + buf.putLong(4) + buf.putInt(GGMLQuantizationType.F32.value) + buf.putLong(0) + + // Tensor 2: "weight_q80", Q8_0, shape [32] + val name2 = "weight_q80".encodeToByteArray() + buf.putLong(name2.size.toLong()) + buf.put(name2) + buf.putInt(1) + buf.putLong(32) + buf.putInt(GGMLQuantizationType.Q8_0.value) + buf.putLong(16) + + // Alignment padding + val padding = (32 - (buf.position() % 32)) % 32 + for (i in 0 until padding) buf.put(0) + + // F32 data: [1.0, 2.0, 3.0, 4.0] + buf.putFloat(1.0f) + buf.putFloat(2.0f) + buf.putFloat(3.0f) + buf.putFloat(4.0f) + + // Q8_0 data: scale=1.0 (f16 0x3C00) + codes 1..32 + buf.put(0x00.toByte()) + buf.put(0x3C.toByte()) + for (i in 1..32) buf.put(i.toByte()) + + buf.flip() + val bytes = ByteArray(buf.remaining()) + buf.get(bytes) + raf.write(bytes) + } + return file + } + + @Test + fun `load F32 tensor produces dense float tensor`() { + val file = createTestGgufFile() + try { + val ctx = DefaultDataExecutionContext() + val loaded = mutableMapOf>() + + kotlinx.coroutines.runBlocking { + StreamingGgufParametersLoader( + sourceProvider = { JvmRandomAccessSource.open(file) } + ).load(ctx, 
FP32::class) { name, tensor -> + loaded[name] = tensor + } + } + + assertTrue("weight_f32" in loaded) + val t = loaded["weight_f32"]!! + assertEquals(Shape(4), t.shape) + assertTrue(t.data is FloatArrayTensorData<*>) + val buf = (t.data as FloatArrayTensorData<*>).buffer + assertEquals(1.0f, buf[0]) + assertEquals(4.0f, buf[3]) + } finally { + file.delete() + } + } + + @Test + fun `load Q8_0 tensor produces packed block TensorData`() { + val file = createTestGgufFile() + try { + val ctx = DefaultDataExecutionContext() + val loaded = mutableMapOf>() + + kotlinx.coroutines.runBlocking { + StreamingGgufParametersLoader( + sourceProvider = { JvmRandomAccessSource.open(file) } + ).load(ctx, FP32::class) { name, tensor -> + loaded[name] = tensor + } + } + + assertTrue("weight_q80" in loaded) + val t = loaded["weight_q80"]!! + assertEquals(Shape(32), t.shape) + // Q8_0 data should be packed, implementing PackedBlockStorage + assertTrue(t.data is PackedBlockStorage, "Q8_0 tensor should be PackedBlockStorage") + } finally { + file.delete() + } + } + + @Test + fun `progress callback invoked correctly`() { + val file = createTestGgufFile() + try { + val ctx = DefaultDataExecutionContext() + val progressCalls = mutableListOf>() + + kotlinx.coroutines.runBlocking { + StreamingGgufParametersLoader( + sourceProvider = { JvmRandomAccessSource.open(file) }, + onProgress = { current, total, msg -> progressCalls.add(Triple(current, total, msg)) } + ).load(ctx, FP32::class) { _, _ -> } + } + + // 2 tensors → 2 progress calls + assertEquals(2, progressCalls.size) + assertEquals(1L, progressCalls[0].first) + assertEquals(2L, progressCalls[0].second) + assertEquals(2L, progressCalls[1].first) + assertEquals(2L, progressCalls[1].second) + } finally { + file.delete() + } + } +} diff --git a/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoader.kt 
b/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoader.kt index c293c0ff..1aff343c 100644 --- a/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoader.kt +++ b/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoader.kt @@ -25,6 +25,10 @@ import kotlin.reflect.KClass * - I8/U8 tensors -> Int8 * - F16/BF16 tensors -> FP32 (with dequantization) * + * Where possible, decoded arrays are wrapped (borrowed) rather than copied + * into TensorData, avoiding a second allocation. The raw-byte decode step + * (little-endian bytes → typed array) is still necessary. + * * @param sourceProvider Factory providing RandomAccessSource to the SafeTensors file * @param onProgress Optional progress callback (current, total, tensorName) */ @@ -54,7 +58,8 @@ class SafeTensorsParametersLoader( "SafeTensors F32 tensor '${tensorInfo.name}' requires FP32 dtype, got ${dtype.simpleName}" } val floats = bytesToFloatArray(bytes) - ctx.fromFloatArray(shape, dtype, floats) as Tensor + // Wrap the decoded array (zero-copy) — it was freshly allocated by bytesToFloatArray + ctx.wrapFloatArray(shape, dtype, floats) as Tensor } DataType.FLOAT64 -> { @@ -64,7 +69,7 @@ class SafeTensorsParametersLoader( println("WARNING: Downcasting F64 tensor '${tensorInfo.name}' to F32") val doubles = bytesToDoubleArray(bytes) val floats = FloatArray(doubles.size) { doubles[it].toFloat() } - ctx.fromFloatArray(shape, dtype, floats) as Tensor + ctx.wrapFloatArray(shape, dtype, floats) as Tensor } DataType.FLOAT16 -> { @@ -72,7 +77,7 @@ class SafeTensorsParametersLoader( "SafeTensors F16 tensor '${tensorInfo.name}' requires FP32 dtype (dequant), got ${dtype.simpleName}" } val floats = dequantF16(bytes) - ctx.fromFloatArray(shape, dtype, floats) as Tensor + ctx.wrapFloatArray(shape, dtype, floats) as Tensor } DataType.BFLOAT16 -> { @@ -80,7 +85,7 @@ class 
SafeTensorsParametersLoader( "SafeTensors BF16 tensor '${tensorInfo.name}' requires FP32 dtype (dequant), got ${dtype.simpleName}" } val floats = dequantBF16(bytes) - ctx.fromFloatArray(shape, dtype, floats) as Tensor + ctx.wrapFloatArray(shape, dtype, floats) as Tensor } DataType.INT32 -> { @@ -88,7 +93,7 @@ class SafeTensorsParametersLoader( "SafeTensors I32 tensor '${tensorInfo.name}' requires Int32 dtype, got ${dtype.simpleName}" } val ints = bytesToIntArray(bytes) - ctx.fromIntArray(shape, dtype, ints) as Tensor + ctx.wrapIntArray(shape, dtype, ints) as Tensor } DataType.INT64 -> { @@ -98,7 +103,7 @@ class SafeTensorsParametersLoader( println("WARNING: Downcasting I64 tensor '${tensorInfo.name}' to I32") val longs = bytesToLongArray(bytes) val ints = IntArray(longs.size) { longs[it].toInt() } - ctx.fromIntArray(shape, dtype, ints) as Tensor + ctx.wrapIntArray(shape, dtype, ints) as Tensor } DataType.INT8 -> { diff --git a/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/StorageAwareSafeTensorsLoader.kt b/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/StorageAwareSafeTensorsLoader.kt new file mode 100644 index 00000000..87686e9e --- /dev/null +++ b/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/StorageAwareSafeTensorsLoader.kt @@ -0,0 +1,95 @@ +package sk.ainet.io.safetensors + +import sk.ainet.io.RandomAccessSource +import sk.ainet.lang.tensor.storage.TensorStorage + +/** + * SafeTensors loader that produces [TensorStorage] descriptors with + * zero-copy file-backed handles where possible. + * + * Unlike [SafeTensorsParametersLoader] which always decodes into typed arrays, + * this loader returns raw [TensorStorage] descriptors that can be: + * + * - **File-backed (zero-copy)**: When a file path is provided, tensors reference + * the original file via [BufferHandle.FileBacked][sk.ainet.lang.tensor.storage.BufferHandle.FileBacked]. 
+ * No heap allocation occurs for the tensor data itself. + * + * - **Borrowed (single allocation)**: When no file path is available, tensor bytes + * are loaded into a single ByteArray and wrapped as + * [BufferHandle.Borrowed][sk.ainet.lang.tensor.storage.BufferHandle.Borrowed]. + * + * Usage: + * ```kotlin + * // Zero-copy: tensors reference the file directly + * val loader = StorageAwareSafeTensorsLoader(sourceProvider, filePath = "/models/model.safetensors") + * val tensors = loader.loadAll() + * // tensors[0].isFileBacked == true + * + * // Heap-loaded: tensors are borrowed byte arrays + * val loader = StorageAwareSafeTensorsLoader(sourceProvider) + * val tensors = loader.loadAll() + * // tensors[0].ownership == Ownership.BORROWED + * ``` + */ +public class StorageAwareSafeTensorsLoader( + private val sourceProvider: () -> RandomAccessSource, + private val filePath: String? = null, + private val onProgress: (current: Long, total: Long, tensorName: String?) -> Unit = { _, _, _ -> } +) { + /** + * Load all tensors as [TensorStorage] descriptors. + * + * When [filePath] is set, returns file-backed storage (zero-copy). + * Otherwise, returns borrowed storage with heap-loaded bytes. + * + * @return Map of tensor name to [TensorStorage] + */ + public fun loadAll(): Map { + val result = mutableMapOf() + StreamingSafeTensorsReader.open(sourceProvider()).use { reader -> + val tensors = reader.tensors + val total = tensors.size.toLong() + var current = 0L + + for (tensorInfo in tensors) { + val storage = if (filePath != null) { + reader.loadTensorStorageMapped(tensorInfo, filePath) + } else { + reader.loadTensorStorage(tensorInfo) + } + result[tensorInfo.name] = storage + current++ + onProgress(current, total, tensorInfo.name) + } + } + return result + } + + /** + * Load a single tensor by name as [TensorStorage]. 
+ * + * @param name The tensor name + * @return [TensorStorage] descriptor + * @throws IllegalArgumentException if tensor not found + */ + public fun load(name: String): TensorStorage { + StreamingSafeTensorsReader.open(sourceProvider()).use { reader -> + val tensorInfo = reader.tensors.firstOrNull { it.name == name } + ?: throw IllegalArgumentException("Tensor not found: $name") + return if (filePath != null) { + reader.loadTensorStorageMapped(tensorInfo, filePath) + } else { + reader.loadTensorStorage(tensorInfo) + } + } + } + + /** + * List all tensor names and their metadata without loading data. + */ + public fun listTensors(): List { + StreamingSafeTensorsReader.open(sourceProvider()).use { reader -> + return reader.tensors.toList() + } + } +} diff --git a/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/StreamingSafeTensorsReader.kt b/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/StreamingSafeTensorsReader.kt index d14909b7..8cfc7b24 100644 --- a/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/StreamingSafeTensorsReader.kt +++ b/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/StreamingSafeTensorsReader.kt @@ -2,6 +2,8 @@ package sk.ainet.io.safetensors import sk.ainet.io.RandomAccessSource import sk.ainet.io.model.DataType +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.storage.* /** * Streaming SafeTensors reader that parses metadata without loading tensor data. @@ -86,6 +88,83 @@ public class StreamingSafeTensorsReader private constructor( return source.readAt(tensor.absoluteDataOffset, buffer, offset, tensor.sizeInBytes) } + // ========== TensorStorage Loading ========== + + /** + * Load a tensor as a [TensorStorage] descriptor with borrowed bytes. 
+ */ + public fun loadTensorStorage(tensor: StreamingSafeTensorInfo): TensorStorage { + val bytes = loadTensorData(tensor) + val shape = Shape(*tensor.shape.map { it.toInt() }.toIntArray()) + return TensorStorage( + shape = shape, + logicalType = safeTensorsTypeToLogical(tensor.dataType), + encoding = safeTensorsTypeToEncoding(tensor.dataType), + buffer = BufferHandle.Borrowed(bytes, isMutable = false), + placement = Placement.CPU_HEAP + ) + } + + /** + * Load a tensor by name as a [TensorStorage] descriptor. + */ + public fun loadTensorStorage(name: String): TensorStorage { + val tensor = _tensors.firstOrNull { it.name == name } + ?: throw IllegalArgumentException("Tensor not found: $name") + return loadTensorStorage(tensor) + } + + /** + * Create a file-backed [TensorStorage] that references the tensor's bytes + * in the original file without loading them into heap. + * + * @param tensor The tensor info from [tensors] list + * @param filePath Path to the SafeTensors file + */ + public fun loadTensorStorageMapped(tensor: StreamingSafeTensorInfo, filePath: String): TensorStorage { + val shape = Shape(*tensor.shape.map { it.toInt() }.toIntArray()) + return TensorStorage( + shape = shape, + logicalType = safeTensorsTypeToLogical(tensor.dataType), + encoding = safeTensorsTypeToEncoding(tensor.dataType), + buffer = BufferHandle.FileBacked( + path = filePath, + fileOffset = tensor.absoluteDataOffset, + sizeInBytes = tensor.sizeInBytes.toLong() + ), + placement = Placement.MMAP_WEIGHTS + ) + } + + private fun safeTensorsTypeToLogical(type: DataType): LogicalDType = when (type) { + DataType.FLOAT32 -> LogicalDType.FLOAT32 + DataType.FLOAT64 -> LogicalDType.FLOAT64 + DataType.FLOAT16 -> LogicalDType.FLOAT16 + DataType.BFLOAT16 -> LogicalDType.BFLOAT16 + DataType.INT8 -> LogicalDType.INT8 + DataType.INT16 -> LogicalDType.INT16 + DataType.INT32 -> LogicalDType.INT32 + DataType.INT64 -> LogicalDType.INT64 + DataType.UINT8 -> LogicalDType.UINT8 + DataType.UINT16 -> 
LogicalDType.UINT16 + DataType.UINT32 -> LogicalDType.UINT32 + DataType.UINT64 -> LogicalDType.UINT64 + DataType.BOOL -> LogicalDType.UINT8 + else -> LogicalDType.INT8 // fallback for UNKNOWN + } + + private fun safeTensorsTypeToEncoding(type: DataType): TensorEncoding = when (type) { + DataType.FLOAT32 -> TensorEncoding.Dense(4) + DataType.FLOAT64 -> TensorEncoding.Dense(8) + DataType.FLOAT16 -> TensorEncoding.Dense(2) + DataType.BFLOAT16 -> TensorEncoding.Dense(2) + DataType.INT8, DataType.UINT8, DataType.BOOL -> TensorEncoding.Dense(1) + DataType.INT16, DataType.UINT16 -> TensorEncoding.Dense(2) + DataType.INT32, DataType.UINT32 -> TensorEncoding.Dense(4) + DataType.INT64, DataType.UINT64 -> TensorEncoding.Dense(8) + else -> TensorEncoding.Dense(1) + } + // ========== Parsing Implementation ========== private fun parse() { diff --git a/skainet-io/skainet-io-safetensors/src/commonTest/kotlin/sk/ainet/io/safetensors/StorageAwareSafeTensorsLoaderTest.kt b/skainet-io/skainet-io-safetensors/src/commonTest/kotlin/sk/ainet/io/safetensors/StorageAwareSafeTensorsLoaderTest.kt new file mode 100644 index 00000000..bdfa058a --- /dev/null +++ b/skainet-io/skainet-io-safetensors/src/commonTest/kotlin/sk/ainet/io/safetensors/StorageAwareSafeTensorsLoaderTest.kt @@ -0,0 +1,204 @@ +package sk.ainet.io.safetensors + +import sk.ainet.io.RandomAccessSource +import sk.ainet.lang.tensor.storage.LogicalDType +import sk.ainet.lang.tensor.storage.MemoryDomain +import sk.ainet.lang.tensor.storage.Ownership +import sk.ainet.lang.tensor.storage.Residency +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +/** + * Tests for [StorageAwareSafeTensorsLoader]. + */ +class StorageAwareSafeTensorsLoaderTest { + + /** + * Create a minimal valid SafeTensors file in memory. + * + * SafeTensors format: 8-byte header size (LE) + JSON header + tensor data. 
+ */ + private fun createSafeTensorsBytes( + tensors: Map>> = mapOf( + "weight" to ("F32" to listOf(2L, 3L)) + ) + ): ByteArray { + // Build tensor data and header entries + val tensorEntries = mutableListOf() + val dataChunks = mutableListOf() + var offset = 0L + + for ((name, info) in tensors) { + val (dtype, shape) = info + val bytesPerElement = when (dtype) { + "F32" -> 4 + "F16" -> 2 + "I32" -> 4 + "I8" -> 1 + else -> 4 + } + val elementCount = if (shape.isEmpty()) 1L else shape.fold(1L) { a, b -> a * b } + val sizeInBytes = elementCount * bytesPerElement + val data = ByteArray(sizeInBytes.toInt()) + // Fill with recognizable pattern + for (i in data.indices) data[i] = (i % 256).toByte() + dataChunks.add(data) + + val shapeStr = shape.joinToString(",") + tensorEntries.add( + "\"$name\":{\"dtype\":\"$dtype\",\"shape\":[$shapeStr],\"data_offsets\":[$offset,${offset + sizeInBytes}]}" + ) + offset += sizeInBytes + } + + val headerJson = "{${tensorEntries.joinToString(",")}}" + val headerBytes = headerJson.encodeToByteArray() + val headerSize = headerBytes.size.toLong() + + // 8 bytes header size (LE) + header + data + val result = ByteArray(8 + headerBytes.size + dataChunks.sumOf { it.size }) + // Write header size as LE u64 + for (i in 0 until 8) { + result[i] = ((headerSize shr (i * 8)) and 0xFF).toByte() + } + headerBytes.copyInto(result, 8) + var dataOffset = 8 + headerBytes.size + for (chunk in dataChunks) { + chunk.copyInto(result, dataOffset) + dataOffset += chunk.size + } + return result + } + + private fun bytesAsSource(bytes: ByteArray): RandomAccessSource { + return object : RandomAccessSource { + override val size: Long get() = bytes.size.toLong() + + override fun readAt(offset: Long, length: Int): ByteArray { + return bytes.copyOfRange(offset.toInt(), offset.toInt() + length) + } + + override fun readAt(offset: Long, buffer: ByteArray, bufferOffset: Int, length: Int): Int { + bytes.copyInto(buffer, bufferOffset, offset.toInt(), offset.toInt() + 
length) + return length + } + + override fun close() {} + } + } + + // --- Heap-loaded (borrowed) mode --- + + @Test + fun loadAllBorrowed_returnsCorrectStorage() { + val fileBytes = createSafeTensorsBytes() + val loader = StorageAwareSafeTensorsLoader( + sourceProvider = { bytesAsSource(fileBytes) } + ) + + val tensors = loader.loadAll() + assertEquals(1, tensors.size) + assertTrue(tensors.containsKey("weight")) + + val storage = tensors["weight"]!! + assertEquals(LogicalDType.FLOAT32, storage.logicalType) + assertEquals(Ownership.BORROWED, storage.ownership) + assertFalse(storage.isFileBacked) + assertEquals(6L, storage.elementCount) // 2 * 3 + } + + // --- File-backed (zero-copy) mode --- + + @Test + fun loadAllMapped_returnsFileBackedStorage() { + val fileBytes = createSafeTensorsBytes() + val loader = StorageAwareSafeTensorsLoader( + sourceProvider = { bytesAsSource(fileBytes) }, + filePath = "/test/model.safetensors" + ) + + val tensors = loader.loadAll() + val storage = tensors["weight"]!! 
+ assertTrue(storage.isFileBacked) + assertEquals(Ownership.FILE_BACKED, storage.ownership) + assertEquals(MemoryDomain.MMAP_FILE, storage.placement.domain) + assertEquals(Residency.PERSISTENT, storage.placement.residency) + assertFalse(storage.isMutable) + } + + // --- Single tensor load --- + + @Test + fun loadSingleTensor() { + val fileBytes = createSafeTensorsBytes( + mapOf( + "a" to ("F32" to listOf(4L)), + "b" to ("F32" to listOf(8L)) + ) + ) + val loader = StorageAwareSafeTensorsLoader( + sourceProvider = { bytesAsSource(fileBytes) } + ) + + val storageA = loader.load("a") + assertEquals(4L, storageA.elementCount) + + val storageB = loader.load("b") + assertEquals(8L, storageB.elementCount) + } + + @Test + fun loadMissingTensorThrows() { + val fileBytes = createSafeTensorsBytes() + val loader = StorageAwareSafeTensorsLoader( + sourceProvider = { bytesAsSource(fileBytes) } + ) + + assertFailsWith { + loader.load("nonexistent") + } + } + + // --- List tensors --- + + @Test + fun listTensorsReturnsMetadata() { + val fileBytes = createSafeTensorsBytes( + mapOf( + "embed" to ("F32" to listOf(100L, 64L)), + "bias" to ("F32" to listOf(64L)) + ) + ) + val loader = StorageAwareSafeTensorsLoader( + sourceProvider = { bytesAsSource(fileBytes) } + ) + + val infos = loader.listTensors() + assertEquals(2, infos.size) + assertEquals(setOf("embed", "bias"), infos.map { it.name }.toSet()) + } + + // --- Progress callback --- + + @Test + fun progressCallbackIsCalled() { + val fileBytes = createSafeTensorsBytes( + mapOf( + "a" to ("F32" to listOf(4L)), + "b" to ("F32" to listOf(8L)) + ) + ) + val progressCalls = mutableListOf>() + val loader = StorageAwareSafeTensorsLoader( + sourceProvider = { bytesAsSource(fileBytes) }, + onProgress = { current, total, name -> progressCalls.add(Triple(current, total, name)) } + ) + + loader.loadAll() + assertEquals(2, progressCalls.size) + assertEquals(2L, progressCalls[1].second) // total + } +} diff --git 
a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/context/ExecutionContext.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/context/ExecutionContext.kt index e1edaf58..1ae29071 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/context/ExecutionContext.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/context/ExecutionContext.kt @@ -6,6 +6,8 @@ import sk.ainet.lang.tensor.data.TensorData import sk.ainet.lang.tensor.data.TensorDataFactory import sk.ainet.lang.tensor.operators.OpsBoundTensor import sk.ainet.lang.tensor.ops.TensorOps +import sk.ainet.lang.tensor.storage.MemoryPlanner +import sk.ainet.lang.tensor.storage.MemoryTracker import sk.ainet.lang.types.DType import kotlin.reflect.KClass @@ -87,8 +89,52 @@ public interface ExecutionContext { return fromData(data, dtype) } + /** + * Wraps a FloatArray without copying (borrow semantics). + * The caller must ensure the array is not mutated while the tensor is in use. + */ + public fun wrapFloatArray( + shape: Shape, + dtype: KClass, + data: FloatArray + ): Tensor { + val tensorData = tensorDataFactory.wrapFloatArray(shape, dtype, data) + return fromData(tensorData, dtype) + } + + /** + * Wraps an IntArray without copying (borrow semantics). + * The caller must ensure the array is not mutated while the tensor is in use. + */ + public fun wrapIntArray( + shape: Shape, + dtype: KClass, + data: IntArray + ): Tensor { + val tensorData = tensorDataFactory.wrapIntArray(shape, dtype, data) + return fromData(tensorData, dtype) + } + + /** + * Wraps a ByteArray without copying (borrow semantics). + * The caller must ensure the array is not mutated while the tensor is in use. 
+ */ + public fun wrapByteArray( + shape: Shape, + dtype: KClass, + data: ByteArray + ): Tensor { + val tensorData = tensorDataFactory.wrapByteArray(shape, dtype, data) + return fromData(tensorData, dtype) + } // runtime information public val memoryInfo: MemoryInfo public val executionStats: ExecutionStats + + /** Memory planner for resolving placement intents. Default: CPU-only. */ + public val memoryPlanner: MemoryPlanner get() = MemoryPlanner() + + /** Memory tracker for observability and copy tracing. Default: no-op (not tracking). */ + public val memoryTracker: MemoryTracker? get() = null } diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/LayerScale.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/LayerScale.kt new file mode 100644 index 00000000..dd797571 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/LayerScale.kt @@ -0,0 +1,38 @@ +package sk.ainet.lang.nn + +import sk.ainet.context.ExecutionContext +import sk.ainet.lang.nn.topology.ModuleParameter +import sk.ainet.lang.nn.topology.ModuleParameters +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.types.DType + +/** + * Layer Scale: element-wise multiplication by a learnable per-channel scalar. + * + * Introduced in "Going deeper with Image Transformers" (CaiT). Used in + * vision transformers and audio codec decoders (Voxtral). + * + * @param dim Number of channels + * @param name Module name + * @param initScale Initial scale tensor (shape: [dim]), typically initialized to a small value (e.g. 0.01) + */ +public class LayerScale( + public val dim: Int, + override val name: String = "LayerScale", + initScale: Tensor? 
= null +) : Module(), ModuleParameters { + + override val params: List> = buildList { + if (initScale != null) { + add(ModuleParameter.WeightParameter("$name.gamma", initScale)) + } + } + + override val modules: List> = emptyList() + + override fun forward(input: Tensor, ctx: ExecutionContext): Tensor = + sk.ainet.lang.nn.hooks.withForwardHooks(ctx, this, input) { + if (params.isEmpty()) return@withForwardHooks input + ctx.ops.multiply(input, params[0].value) + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/TransposedConv1d.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/TransposedConv1d.kt new file mode 100644 index 00000000..8f348ad8 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/TransposedConv1d.kt @@ -0,0 +1,93 @@ +package sk.ainet.lang.nn + +import sk.ainet.context.ExecutionContext +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.types.DType +import sk.ainet.lang.nn.topology.ModuleParameter +import sk.ainet.lang.nn.topology.ModuleParameters +import sk.ainet.lang.nn.topology.bias +import sk.ainet.lang.nn.topology.weights + +/** + * 1D Transposed Convolutional layer (also known as fractionally-strided convolution). + * + * This layer is commonly used for upsampling in generative models, audio synthesis + * (e.g. BigVGAN, Voxtral codec), and decoder architectures. 
+ * + * @param inChannels Number of input channels + * @param outChannels Number of output channels + * @param kernelSize Size of the convolving kernel + * @param stride Stride of the convolution (default: 1) + * @param padding Padding added to both sides of the input (default: 0) + * @param outputPadding Additional size added to one side of the output (default: 0) + * @param dilation Spacing between kernel elements (default: 1) + * @param groups Number of blocked connections from input channels to output channels (default: 1) + * @param bias Whether to add a learnable bias to the output (default: true) + * @param name Name of the module + * @param initWeights Initial weights tensor + * @param initBias Initial bias tensor (if bias is true) + */ +public class TransposedConv1d( + public val inChannels: Int, + public val outChannels: Int, + public val kernelSize: Int, + public val stride: Int = 1, + public val padding: Int = 0, + public val outputPadding: Int = 0, + public val dilation: Int = 1, + public val groups: Int = 1, + public val bias: Boolean = true, + override val name: String = "TransposedConv1d", + initWeights: Tensor, + initBias: Tensor? 
= null, + public val trainable: Boolean = true +) : Module(), ModuleParameters { + + init { + require(inChannels > 0) { "inChannels must be positive" } + require(outChannels > 0) { "outChannels must be positive" } + require(kernelSize > 0) { "kernelSize must be positive" } + require(stride > 0) { "stride must be positive" } + require(padding >= 0) { "padding must be non-negative" } + require(outputPadding >= 0) { "outputPadding must be non-negative" } + require(outputPadding < stride) { "outputPadding must be less than stride" } + require(dilation > 0) { "dilation must be positive" } + require(groups > 0) { "groups must be positive" } + require(inChannels % groups == 0) { "inChannels must be divisible by groups" } + require(outChannels % groups == 0) { "outChannels must be divisible by groups" } + } + + override val params: List> = buildList { + add(ModuleParameter.WeightParameter("$name.weight", initWeights, trainable)) + if (bias && initBias != null) { + add(ModuleParameter.BiasParameter("$name.bias", initBias, trainable)) + } + } + + override val modules: List> + get() = emptyList() + + override fun forward(input: Tensor, ctx: ExecutionContext): Tensor = + sk.ainet.lang.nn.hooks.withForwardHooks(ctx, this, input) { + val weight = params.weights().value + val biasValue = if (bias) params.bias().value else null + + input.ops.convTranspose1d( + input = input, + weight = weight, + bias = biasValue, + stride = stride, + padding = padding, + outputPadding = outputPadding, + dilation = dilation, + groups = groups + ) + } + + /** + * Calculates the output size for a given input size and transposed convolution parameters. 
+ */ + public fun outputSize(inputSize: Int): Int { + return (inputSize - 1) * stride - 2 * padding + dilation * (kernelSize - 1) + outputPadding + 1 + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/activations/Snake.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/activations/Snake.kt new file mode 100644 index 00000000..0da726e0 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/activations/Snake.kt @@ -0,0 +1,51 @@ +package sk.ainet.lang.nn.activations + +import sk.ainet.context.ExecutionContext +import sk.ainet.lang.nn.Module +import sk.ainet.lang.nn.topology.ModuleParameter +import sk.ainet.lang.nn.topology.ModuleParameters +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.types.DType + +/** + * Snake activation function: f(x) = x + sin²(α * x) / α + * + * Used in audio synthesis models (BigVGAN, Voxtral codec) where it provides + * periodic inductive bias that helps model audio waveforms. + * + * @param channels Number of channels (for per-channel alpha) + * @param name Module name + * @param initAlpha Initial alpha parameter tensor (shape: [channels]) + */ +public class Snake( + public val channels: Int, + override val name: String = "Snake", + initAlpha: Tensor? 
= null +) : Module(), ModuleParameters { + + override val params: List> = buildList { + if (initAlpha != null) { + add(ModuleParameter.WeightParameter("$name.alpha", initAlpha)) + } + } + + override val modules: List> = emptyList() + + override fun forward(input: Tensor, ctx: ExecutionContext): Tensor = + sk.ainet.lang.nn.hooks.withForwardHooks(ctx, this, input) { + val ops = ctx.ops + if (params.isEmpty()) { + // alpha = 1: snake(x) = x + sin²(x) + val sinX = ops.sin(input) + val sin2X = ops.multiply(sinX, sinX) + ops.add(input, sin2X) + } else { + // snake(x) = x + sin²(α*x) / α + val alpha = params[0].value + val ax = ops.multiply(input, alpha) + val sinAx = ops.sin(ax) + val sin2Ax = ops.multiply(sinAx, sinAx) + ops.add(input, ops.divide(sin2Ax, alpha)) + } + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/CopyMaterializationStrategy.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/CopyMaterializationStrategy.kt index 81c46aa4..59588c7f 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/CopyMaterializationStrategy.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/CopyMaterializationStrategy.kt @@ -2,6 +2,7 @@ package sk.ainet.lang.tensor import sk.ainet.lang.tensor.data.TensorData import sk.ainet.lang.tensor.ops.TensorOps +import sk.ainet.lang.tensor.storage.ActiveMemoryTracker import sk.ainet.lang.types.DType import sk.ainet.lang.types.FP16 import sk.ainet.lang.types.FP32 @@ -66,6 +67,9 @@ public class CopyMaterializationStrategy : MaterializationStrategy // Copy all elements from the view to the new array copyViewElements(view, materializedData, viewShape) + // Record the copy for memory tracking + ActiveMemoryTracker.recordCopy("CopyMaterializationStrategy", viewVolume.toLong() * 4) + // Create and return the materialized tensor return createMaterializedTensor(view, materializedData, viewShape) } diff --git 
a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/MaterializationExtensions.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/MaterializationExtensions.kt index 62e4c5d1..7bcf3243 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/MaterializationExtensions.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/MaterializationExtensions.kt @@ -146,4 +146,60 @@ public fun TensorView.estimateMaterializationCost( strategy: MaterializationStrategy ): Long { return strategy.estimateMemoryOverhead(this) +} + +// --- Explicit copy/alias operations (Phase 1b: memory-first) --- + +/** + * Explicitly copies this view into a standalone contiguous tensor. + * + * This is the same operation as [materialize] but with a name that makes + * the copy semantics unambiguous. Prefer this over [materialize] in new code. + * + * @return a new Tensor containing a copied, contiguous copy of this view's data + */ +public fun TensorView.copyMaterialize(): Tensor { + val strategy = CopyMaterializationStrategy() + return strategy.materialize(this) +} + +/** + * Realizes this view as an alias — returns a tensor that shares the parent's + * backing data when the view is a simple contiguous slice. + * + * If the view's [IndexMapper] reports that it is contiguous, this returns + * a lightweight tensor backed by the same data (zero-copy). Otherwise it + * falls back to [copyMaterialize]. + * + * @return a Tensor that either aliases the parent data or is a copy + */ +public fun TensorView.realizeAlias(): Tensor { + return if (indexMapping.isContiguous()) { + // Contiguous view: create a tensor that shares the parent's data + // but uses the view's shape. This is zero-copy. 
+ AliasedTensor( + data = parentTensor.data, + ops = ops, + dtype = dtype, + gradState = gradState, + aliasedShape = viewShape + ) + } else { + // Non-contiguous view: must copy + copyMaterialize() + } +} + +/** + * Internal tensor wrapper that aliases parent data with a different shape. + * Used by [realizeAlias] for contiguous views. + */ +internal class AliasedTensor( + override val data: sk.ainet.lang.tensor.data.TensorData, + override val ops: sk.ainet.lang.tensor.ops.TensorOps, + override val dtype: kotlin.reflect.KClass, + override val gradState: GradState, + private val aliasedShape: Shape +) : Tensor { + override val shape: Shape get() = aliasedShape } \ No newline at end of file diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/DenseTensorDataFactory.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/DenseTensorDataFactory.kt index 4ca69305..8956fe9c 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/DenseTensorDataFactory.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/DenseTensorDataFactory.kt @@ -2,6 +2,7 @@ package sk.ainet.lang.tensor.data import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.data.dense.DenseByteTensorArray +import sk.ainet.lang.tensor.storage.ActiveMemoryTracker import sk.ainet.lang.types.DType import sk.ainet.lang.types.FP16 import sk.ainet.lang.types.FP32 @@ -143,11 +144,13 @@ public class DenseTensorDataFactory: TensorDataFactory { // Helper methods to create tensor data instances private fun createIntTensorData(shape: Shape, data: IntArray): TensorData { + ActiveMemoryTracker.recordCopy("DenseTensorDataFactory.createIntTensorData", data.size.toLong() * 4) return DenseIntArrayTensorData(shape, data.copyOf()) } @Suppress("UNCHECKED_CAST") private fun createFloatTensorData(shape: Shape, data: FloatArray, dtype: T): TensorData { + 
ActiveMemoryTracker.recordCopy("DenseTensorDataFactory.createFloatTensorData", data.size.toLong() * 4) return DenseFloatArrayTensorData(shape, data.copyOf()) as TensorData } @@ -647,4 +650,60 @@ public class DenseTensorDataFactory: TensorDataFactory { else -> throw IllegalArgumentException("fromByteArray only supports Int8 types with shape: $dtype") } } + + // --- Zero-copy wrap methods (borrow semantics) --- + + override fun wrapFloatArray( + shape: Shape, + dtype: KClass, + data: FloatArray + ): TensorData { + require(data.size == shape.volume) { + "Data size ${data.size} doesn't match shape volume ${shape.volume}" + } + @Suppress("UNCHECKED_CAST") + return when (dtype) { + FP32::class -> DenseFloatArrayTensorData(shape, data) as TensorData + FP16::class -> DenseFloatArrayTensorData(shape, data) as TensorData + else -> throw IllegalArgumentException("wrapFloatArray only supports floating point types: $dtype") + } + } + + override fun wrapIntArray( + shape: Shape, + dtype: KClass, + data: IntArray + ): TensorData { + require(data.size == shape.volume) { + "Data size ${data.size} doesn't match shape volume ${shape.volume}" + } + @Suppress("UNCHECKED_CAST") + return when (dtype) { + Int32::class -> DenseIntArrayTensorData(shape, data) as TensorData + else -> throw IllegalArgumentException("wrapIntArray only supports Int32 types: $dtype") + } + } + + override fun wrapByteArray( + shape: Shape, + dtype: KClass, + data: ByteArray + ): TensorData { + require(data.size == shape.volume) { + "Data size ${data.size} doesn't match shape volume ${shape.volume}" + } + @Suppress("UNCHECKED_CAST") + return when (dtype) { + Int8::class -> { + val denseArray = DenseByteTensorArray(shape, data) + class WrappedByteTensorData( + private val inner: DenseByteTensorArray + ) : TensorData, ItemsAccessor by inner { + override val shape: Shape = inner.shape + } + WrappedByteTensorData(denseArray) as TensorData + } + else -> throw IllegalArgumentException("wrapByteArray only supports Int8 
types: $dtype")
+        }
+    }
 }
diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/Q4_KTensorData.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/Q4_KTensorData.kt
index ee179497..1cdc60d1 100644
--- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/Q4_KTensorData.kt
+++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/Q4_KTensorData.kt
@@ -1,6 +1,8 @@
 package sk.ainet.lang.tensor.data
 
 import sk.ainet.lang.tensor.Shape
+import sk.ainet.lang.tensor.storage.PackedBlockStorage
+import sk.ainet.lang.tensor.storage.TensorEncoding
 import sk.ainet.lang.types.DType
 
 /**
@@ -75,7 +77,7 @@ public interface Q4_KTensorData : TensorData {
 public class Q4_KBlockTensorData(
     initialShape: Shape,
     private val data: ByteArray
-) : Q4_KTensorData {
+) : Q4_KTensorData, PackedBlockStorage {
 
     override val shape: Shape = Shape(initialShape.dimensions.copyOf())
     private val strides: IntArray = shape.computeStrides()
@@ -83,6 +85,28 @@ public class Q4_KBlockTensorData(
     override val blockCount: Int =
         (shape.volume + Q4_KTensorData.BLOCK_SIZE - 1) / Q4_KTensorData.BLOCK_SIZE
 
+    // PackedBlockStorage implementation
+    override val encoding: TensorEncoding get() = TensorEncoding.Q4_K
+    override val blockSize: Int get() = Q4_KTensorData.BLOCK_SIZE
+
+    override fun dequantizeBlock(blockIdx: Int, output: FloatArray, outputOffset: Int) {
+        require(blockIdx in 0 until blockCount) { "Block index $blockIdx out of bounds (0 until $blockCount)" }
+        for (subBlockIdx in 0 until Q4_KTensorData.SUB_BLOCKS_PER_BLOCK) {
+            val scale = getSubBlockScale(blockIdx, subBlockIdx)
+            val min = getSubBlockMin(blockIdx, subBlockIdx)
+            val elemsStart = subBlockIdx * Q4_KTensorData.SUB_BLOCK_SIZE
+            for (j in 0 until Q4_KTensorData.SUB_BLOCK_SIZE) {
+                val elementIdx = elemsStart + j
+                val outIdx = outputOffset + elementIdx
+                if (outIdx >= output.size) return
+                val globalIdx = blockIdx *
Q4_KTensorData.BLOCK_SIZE + elementIdx
+                if (globalIdx >= shape.volume) return
+                val code = getCode(blockIdx, elementIdx)
+                output[outIdx] = code * scale + min
+            }
+        }
+    }
+
     init {
         val requiredBytes = blockCount * Q4_KTensorData.BYTES_PER_BLOCK
         require(data.size >= requiredBytes) {
diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/Q8_0TensorData.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/Q8_0TensorData.kt
index f53ef9e6..673d8719 100644
--- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/Q8_0TensorData.kt
+++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/Q8_0TensorData.kt
@@ -1,6 +1,8 @@
 package sk.ainet.lang.tensor.data
 
 import sk.ainet.lang.tensor.Shape
+import sk.ainet.lang.tensor.storage.PackedBlockStorage
+import sk.ainet.lang.tensor.storage.TensorEncoding
 import sk.ainet.lang.types.DType
 
 /**
@@ -50,7 +52,7 @@ public interface Q8_0TensorData : TensorData {
 public class Q8_0BlockTensorData(
     initialShape: Shape,
     private val data: ByteArray
-) : Q8_0TensorData {
+) : Q8_0TensorData, PackedBlockStorage {
 
     override val shape: Shape = Shape(initialShape.dimensions.copyOf())
     private val strides: IntArray = shape.computeStrides()
@@ -58,6 +60,21 @@ public class Q8_0BlockTensorData(
     override val blockCount: Int =
         (shape.volume + Q8_0TensorData.BLOCK_SIZE - 1) / Q8_0TensorData.BLOCK_SIZE
 
+    // PackedBlockStorage implementation
+    override val encoding: TensorEncoding get() = TensorEncoding.Q8_0
+    override val blockSize: Int get() = Q8_0TensorData.BLOCK_SIZE
+
+    override fun dequantizeBlock(blockIdx: Int, output: FloatArray, outputOffset: Int) {
+        require(blockIdx in 0 until blockCount) { "Block index $blockIdx out of bounds (0 until $blockCount)" }
+        val scale = getBlockScale(blockIdx)
+        val elemsInBlock = minOf(Q8_0TensorData.BLOCK_SIZE, shape.volume - blockIdx * Q8_0TensorData.BLOCK_SIZE)
+        for (i in 0 until elemsInBlock) {
+
val outIdx = outputOffset + i
+            if (outIdx >= output.size) return
+            output[outIdx] = getCode(blockIdx, i).toFloat() * scale
+        }
+    }
+
     init {
         val requiredBytes = blockCount * Q8_0TensorData.BYTES_PER_BLOCK
         require(data.size >= requiredBytes) {
diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/TensorDataFactory.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/TensorDataFactory.kt
index 85f242fd..eed1712a 100644
--- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/TensorDataFactory.kt
+++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/TensorDataFactory.kt
@@ -59,6 +59,39 @@ public interface TensorDataFactory {
         dtype: KClass<T>,
         data: ByteArray
     ): TensorData<T>
+
+    /**
+     * Wraps a FloatArray without copying. The caller must ensure the array
+     * is not mutated while the returned TensorData is in use.
+     * Default implementation falls back to [fromFloatArray] (which copies).
+     */
+    public fun <T : DType> wrapFloatArray(
+        shape: Shape,
+        dtype: KClass<T>,
+        data: FloatArray
+    ): TensorData<T> = fromFloatArray(shape, dtype, data)
+
+    /**
+     * Wraps an IntArray without copying. The caller must ensure the array
+     * is not mutated while the returned TensorData is in use.
+     * Default implementation falls back to [fromIntArray] (which copies).
+     */
+    public fun <T : DType> wrapIntArray(
+        shape: Shape,
+        dtype: KClass<T>,
+        data: IntArray
+    ): TensorData<T> = fromIntArray(shape, dtype, data)
+
+    /**
+     * Wraps a ByteArray without copying. The caller must ensure the array
+     * is not mutated while the returned TensorData is in use.
+     * Default implementation falls back to [fromByteArray] (which copies).
+ */ + public fun wrapByteArray( + shape: Shape, + dtype: KClass, + data: ByteArray + ): TensorData = fromByteArray(shape, dtype, data) } /** diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/TernaryTensorData.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/TernaryTensorData.kt index 031681b3..63a5106c 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/TernaryTensorData.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/TernaryTensorData.kt @@ -1,6 +1,8 @@ package sk.ainet.lang.tensor.data import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.storage.PackedBlockStorage +import sk.ainet.lang.tensor.storage.TensorEncoding import sk.ainet.lang.types.Ternary /** @@ -45,12 +47,30 @@ public class Ternary2BitTensorData( initialShape: Shape, private val data: ByteArray, override val scale: Float = 1.0f -) : TernaryTensorData { +) : TernaryTensorData, PackedBlockStorage { override val shape: Shape = Shape(initialShape.dimensions.copyOf()) private val strides: IntArray = shape.computeStrides() override val packedData: ByteArray get() = data + // PackedBlockStorage — treat the whole tensor as a single block + override val encoding: TensorEncoding get() = TensorEncoding.TernaryPacked + override val blockSize: Int get() = shape.volume + override val blockCount: Int get() = 1 + + override fun dequantizeBlock(blockIdx: Int, output: FloatArray, outputOffset: Int) { + require(blockIdx == 0) { "Ternary has a single block, got index $blockIdx" } + for (i in 0 until shape.volume) { + val byteIndex = i / 4 + val bitOffset = (i % 4) * 2 + val encoded = (data[byteIndex].toInt() ushr bitOffset) and 0x03 + val ternary = encoded - 1 // decode: 0→-1, 1→0, 2→+1 + val outIdx = outputOffset + i + if (outIdx >= output.size) return + output[outIdx] = ternary.toFloat() * scale + } + } + init { val requiredBytes = (shape.volume + 3) / 4 
// 4 values per byte require(data.size >= requiredBytes) { diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOps.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOps.kt index 901978af..3f730336 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOps.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOps.kt @@ -77,6 +77,20 @@ public interface TensorOps { groups: Int = 1 ): Tensor + // Transposed convolutional operations + public fun convTranspose1d( + input: Tensor, + weight: Tensor, + bias: Tensor? = null, + stride: Int = 1, + padding: Int = 0, + outputPadding: Int = 0, + dilation: Int = 1, + groups: Int = 1 + ): Tensor { + throw NotImplementedError("convTranspose1d not implemented by this TensorOps backend") + } + // Pooling operations @Diff public fun maxPool2d( @@ -225,6 +239,19 @@ public interface TensorOps { @ActivationDsl public fun expm1(tensor: Tensor): Tensor + // Trigonometric operations + public fun sin(tensor: Tensor): Tensor { + throw NotImplementedError("sin not implemented by this TensorOps backend") + } + + public fun cos(tensor: Tensor): Tensor { + throw NotImplementedError("cos not implemented by this TensorOps backend") + } + + public fun tanh(tensor: Tensor): Tensor { + throw NotImplementedError("tanh not implemented by this TensorOps backend") + } + /** * Scaled dot-product attention. 
* diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/VoidTensorOps.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/VoidTensorOps.kt index 11c8c2f3..e753ab31 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/VoidTensorOps.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/VoidTensorOps.kt @@ -191,6 +191,21 @@ public class VoidTensorOps : TensorOps { return VoidOpsTensor(resultData, input.dtype) } + override fun convTranspose1d( + input: Tensor, + weight: Tensor, + bias: Tensor?, + stride: Int, + padding: Int, + outputPadding: Int, + dilation: Int, + groups: Int + ): Tensor { + val resultShape = calculateConvTranspose1dShape(input.shape, weight.shape, stride, padding, outputPadding, dilation) + val resultData = dataFactory.zeros(resultShape, input.dtype) + return VoidOpsTensor(resultData, input.dtype) + } + override fun maxPool2d( input: Tensor, kernelSize: Pair, @@ -440,6 +455,21 @@ public class VoidTensorOps : TensorOps { return VoidOpsTensor(resultData, tensor.dtype) } + override fun sin(tensor: Tensor): Tensor { + val resultData = dataFactory.zeros(tensor.shape, tensor.dtype) + return VoidOpsTensor(resultData, tensor.dtype) + } + + override fun cos(tensor: Tensor): Tensor { + val resultData = dataFactory.zeros(tensor.shape, tensor.dtype) + return VoidOpsTensor(resultData, tensor.dtype) + } + + override fun tanh(tensor: Tensor): Tensor { + val resultData = dataFactory.zeros(tensor.shape, tensor.dtype) + return VoidOpsTensor(resultData, tensor.dtype) + } + override fun scaledDotProductAttention( query: Tensor, key: Tensor, @@ -776,6 +806,23 @@ public class VoidTensorOps : TensorOps { return Shape(batch, outChannels, outputDepth, outputHeight, outputWidth) } + /** + * Calculates the result shape for convTranspose1d operation. 
+ * Input shape: (batch, in_channels, length) + * Weight shape: (in_channels, out_channels_per_group, kernel_size) + * Output shape: (batch, out_channels, out_length) + */ + private fun calculateConvTranspose1dShape( + inputShape: Shape, weightShape: Shape, stride: Int, padding: Int, outputPadding: Int, dilation: Int + ): Shape { + val batch = inputShape[0] + val outChannels = weightShape[1] + val inputLength = inputShape[2] + val kernelSize = weightShape[2] + val outputLength = (inputLength - 1) * stride - 2 * padding + dilation * (kernelSize - 1) + outputPadding + 1 + return Shape(batch, outChannels, outputLength) + } + /** * Calculates the result shape for conv2d operation. * Input shape: (batch, in_channels, height, width) diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/turboquant/BitPacker.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/turboquant/BitPacker.kt new file mode 100644 index 00000000..648db4fc --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/turboquant/BitPacker.kt @@ -0,0 +1,183 @@ +package sk.ainet.lang.tensor.ops.turboquant + +/** + * Bit-packing and unpacking for TurboQuant codes. + * + * Packs signed N-bit integer codes into compact byte arrays for storage. + * Supports 2, 3, 4, and 8-bit packing. Codes are stored as unsigned + * offsets (biased by 2^(bits-1)) to simplify packing. + * + * Packing is append-friendly: codes can be packed incrementally per token + * without re-packing the entire cache. + */ +public object BitPacker { + + /** + * Pack signed codes into a compact byte array. 
/**
 * Bit-packing and unpacking for TurboQuant codes.
 *
 * Packs signed N-bit integer codes into compact byte arrays for storage.
 * Codes are stored as unsigned offsets (biased by 2^(bits-1)) to simplify
 * packing. Packing is append-friendly: codes can be packed incrementally
 * per token without re-packing the entire cache.
 */
public object BitPacker {

    /**
     * Pack signed codes into a compact byte array.
     *
     * Codes are biased to the unsigned range before packing:
     * `stored = code + 2^(bits-1)`.
     *
     * @param codes Signed codes (values in [-maxCode, maxCode], maxCode = 2^(bits-1) - 1)
     * @param bits Bits per code (2, 3, 4, or 8)
     * @return Packed byte array
     * @throws IllegalArgumentException if [bits] is not one of 2, 3, 4, 8
     */
    public fun pack(codes: ByteArray, bits: Int): ByteArray = when (bits) {
        2 -> pack2Bit(codes)
        3 -> pack3Bit(codes)
        4 -> pack4Bit(codes)
        8 -> pack8Bit(codes)
        else -> throw IllegalArgumentException("bits must be 2, 3, 4, or 8, got $bits")
    }

    /**
     * Unpack a byte array back to signed codes.
     *
     * @param packed Packed byte array produced by [pack]
     * @param count Number of codes to unpack
     * @param bits Bits per code (2, 3, 4, or 8)
     * @return Signed codes
     * @throws IllegalArgumentException if [bits] is not one of 2, 3, 4, 8
     */
    public fun unpack(packed: ByteArray, count: Int, bits: Int): ByteArray = when (bits) {
        2 -> unpack2Bit(packed, count)
        3 -> unpack3Bit(packed, count)
        4 -> unpack4Bit(packed, count)
        8 -> unpack8Bit(packed, count)
        else -> throw IllegalArgumentException("bits must be 2, 3, 4, or 8, got $bits")
    }

    /**
     * Compute the byte size needed to pack [count] codes at [bits] per code.
     *
     * Unlike [pack]/[unpack], this also accepts `bits = 1`: QJL sign
     * residuals are packed externally by `QjlResidual` (8 sign bits per
     * byte), but size budgets (e.g. `TurboQuantCodec.encodedSize` with its
     * default `residualBits = 1`) still need this computation. Rejecting
     * `bits = 1` here previously made `encodedSize` throw for the default
     * QJL configuration.
     *
     * @throws IllegalArgumentException if [bits] is not one of 1, 2, 3, 4, 8
     */
    public fun packedSize(count: Int, bits: Int): Int = when (bits) {
        1 -> (count + 7) / 8
        2 -> (count + 3) / 4
        3 -> (count * 3 + 7) / 8
        4 -> (count + 1) / 2
        8 -> count
        else -> throw IllegalArgumentException("bits must be 1, 2, 3, 4, or 8, got $bits")
    }

    // ========== 2-bit packing ==========
    // 4 codes per byte. Bias = 2: codes in [-1, 1] map to stored values
    // [1, 3] within the 2-bit field [0, 3] (field value 0 is unused).

    private fun pack2Bit(codes: ByteArray): ByteArray {
        val bias = 2 // 2^(2-1)
        val packed = ByteArray((codes.size + 3) / 4)
        for (i in codes.indices) {
            val unsigned = (codes[i].toInt() + bias) and 0x03
            val byteIdx = i / 4
            val shift = (i % 4) * 2
            packed[byteIdx] = (packed[byteIdx].toInt() or (unsigned shl shift)).toByte()
        }
        return packed
    }

    private fun unpack2Bit(packed: ByteArray, count: Int): ByteArray {
        val bias = 2
        val codes = ByteArray(count)
        for (i in 0 until count) {
            val byteIdx = i / 4
            val shift = (i % 4) * 2
            val unsigned = (packed[byteIdx].toInt() ushr shift) and 0x03
            codes[i] = (unsigned - bias).toByte()
        }
        return codes
    }

    // ========== 3-bit packing ==========
    // 8 codes per 3 bytes. Bias = 4: codes in [-3, 3] map to [1, 7].
    // A code may straddle a byte boundary, handled by the overflow path.

    private fun pack3Bit(codes: ByteArray): ByteArray {
        val bias = 4 // 2^(3-1)
        val packed = ByteArray((codes.size * 3 + 7) / 8)
        var bitPos = 0
        for (i in codes.indices) {
            val unsigned = (codes[i].toInt() + bias) and 0x07
            val byteIdx = bitPos / 8
            val bitOffset = bitPos % 8
            packed[byteIdx] = (packed[byteIdx].toInt() or (unsigned shl bitOffset)).toByte()
            // bitOffset 6 or 7: the code crosses into the next byte.
            if (bitOffset > 5) {
                val overflow = unsigned ushr (8 - bitOffset)
                if (byteIdx + 1 < packed.size) {
                    packed[byteIdx + 1] = (packed[byteIdx + 1].toInt() or overflow).toByte()
                }
            }
            bitPos += 3
        }
        return packed
    }

    private fun unpack3Bit(packed: ByteArray, count: Int): ByteArray {
        val bias = 4
        val codes = ByteArray(count)
        var bitPos = 0
        for (i in 0 until count) {
            val byteIdx = bitPos / 8
            val bitOffset = bitPos % 8
            var value = (packed[byteIdx].toInt() ushr bitOffset) and 0x07
            // Reassemble a code that crossed a byte boundary during packing.
            if (bitOffset > 5 && byteIdx + 1 < packed.size) {
                val bitsFromFirst = 8 - bitOffset
                val remaining = 3 - bitsFromFirst
                val fromNext = packed[byteIdx + 1].toInt() and ((1 shl remaining) - 1)
                value = ((packed[byteIdx].toInt() ushr bitOffset) and ((1 shl bitsFromFirst) - 1)) or
                    (fromNext shl bitsFromFirst)
            }
            codes[i] = (value - bias).toByte()
            bitPos += 3
        }
        return codes
    }

    // ========== 4-bit packing ==========
    // 2 codes per byte. Bias = 8: codes in [-7, 7] map to [1, 15].

    private fun pack4Bit(codes: ByteArray): ByteArray {
        val bias = 8 // 2^(4-1)
        val packed = ByteArray((codes.size + 1) / 2)
        for (i in codes.indices) {
            val unsigned = (codes[i].toInt() + bias) and 0x0F
            val byteIdx = i / 2
            if (i % 2 == 0) {
                packed[byteIdx] = (packed[byteIdx].toInt() or unsigned).toByte()
            } else {
                packed[byteIdx] = (packed[byteIdx].toInt() or (unsigned shl 4)).toByte()
            }
        }
        return packed
    }

    private fun unpack4Bit(packed: ByteArray, count: Int): ByteArray {
        val bias = 8
        val codes = ByteArray(count)
        for (i in 0 until count) {
            val byteIdx = i / 2
            val unsigned = if (i % 2 == 0) {
                packed[byteIdx].toInt() and 0x0F
            } else {
                (packed[byteIdx].toInt() ushr 4) and 0x0F
            }
            codes[i] = (unsigned - bias).toByte()
        }
        return codes
    }

    // ========== 8-bit packing ==========
    // 1:1 mapping, codes are already bytes.

    private fun pack8Bit(codes: ByteArray): ByteArray = codes.copyOf()

    private fun unpack8Bit(packed: ByteArray, count: Int): ByteArray {
        return if (packed.size == count) packed.copyOf()
        else packed.copyOfRange(0, count)
    }
}
/**
 * QJL (Quantized Johnson-Lindenstrauss) residual stage for TurboQuant.
 *
 * After scalar quantization there is a residual error
 * (`original_rotated - dequantized`). This stage encodes that residual
 * compactly so that inner-product accuracy is preserved at the cost of a
 * little extra storage. Used only by the Polar+QJL variant; the Polar-only
 * variant omits it for speed.
 */
public object QjlResidual {

    /**
     * Encode a residual vector.
     *
     * For `residualBits == 1` only the sign pattern (after a seeded random
     * sign flip) is kept, with the RMS magnitude as a single scale.
     * For 2-4 bits the residual is scalar-quantized and bit-packed, with
     * the mean of the per-group scales stored as one shared scale.
     *
     * @param residual Quantization residual (original - dequantized)
     * @param residualBits Bits per residual component (1-4)
     * @param seed Seed for the deterministic sign projection
     * @return Encoded residual (packed bytes + scale)
     */
    public fun encode(residual: FloatArray, residualBits: Int, seed: Int): EncodedResidual {
        require(residualBits in 1..4) { "residualBits must be 1-4, got $residualBits" }
        return if (residualBits == 1) {
            encodeSigns(residual, seed)
        } else {
            encodeMultiBit(residual, residualBits)
        }
    }

    // 1-bit path: store sign(residual[i] * randomSign[i]), i.e. a random
    // sign flip followed by sign extraction; scale is the residual's RMS.
    private fun encodeSigns(residual: FloatArray, seed: Int): EncodedResidual {
        val n = residual.size
        val bitBytes = ByteArray((n + 7) / 8)

        var sumSquares = 0f
        for (value in residual) {
            sumSquares += value * value
        }
        val rmsScale = sqrt(sumSquares / n)

        val random = Random(seed)
        for (i in 0 until n) {
            val flip = if (random.nextBoolean()) 1f else -1f
            if (residual[i] * flip >= 0f) {
                bitBytes[i / 8] = (bitBytes[i / 8].toInt() or (1 shl (i % 8))).toByte()
            }
        }
        return EncodedResidual(bitBytes, rmsScale, 1, n)
    }

    // Multi-bit path: scalar-quantize the residual, pack the codes, and
    // collapse the per-group scales into their mean so only one is stored.
    private fun encodeMultiBit(residual: FloatArray, residualBits: Int): EncodedResidual {
        val quantized = ScalarQuantizer.quantize(residual, residualBits)
        val packedBits = BitPacker.pack(quantized.codes, residualBits)
        val groupScales = quantized.scales
        val sharedScale = if (groupScales.isNotEmpty()) {
            groupScales.sum() / groupScales.size
        } else 0f
        return EncodedResidual(packedBits, sharedScale, residualBits, residual.size)
    }

    /**
     * Decode a QJL residual and add it onto [output] in place.
     *
     * @param encoded The encoded residual
     * @param output Accumulator the decoded residual is added into
     * @param seed Same seed used during [encode]
     */
    public fun decode(encoded: EncodedResidual, output: FloatArray, seed: Int) {
        val n = encoded.elementCount
        require(output.size >= n) { "Output size ${output.size} < dim $n" }

        if (encoded.residualBits == 1) {
            // Reconstruct as +-scale, undoing the random sign flip.
            val random = Random(seed)
            for (i in 0 until n) {
                val flip = if (random.nextBoolean()) 1f else -1f
                val bit = (encoded.packed[i / 8].toInt() ushr (i % 8)) and 1
                val magnitude = if (bit == 1) encoded.scale else -encoded.scale
                output[i] += magnitude * flip
            }
        } else {
            // Unpack, dequantize with the shared scale, and accumulate.
            val codes = BitPacker.unpack(encoded.packed, n, encoded.residualBits)
            for (i in 0 until n) {
                output[i] += codes[i].toFloat() * encoded.scale
            }
        }
    }
}

/**
 * Encoded QJL residual data: packed bits plus a single scale factor.
 */
public data class EncodedResidual(
    /** Packed residual bits. */
    val packed: ByteArray,
    /** Scale factor for reconstruction. */
    val scale: Float,
    /** Bits per residual component. */
    val residualBits: Int,
    /** Number of elements. */
    val elementCount: Int
) {
    // Convenience accessor for size accounting.
    val packedSizeBytes: Int get() = packed.size

    // Manual equals/hashCode because ByteArray uses reference equality.
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is EncodedResidual) return false
        return scale == other.scale &&
            residualBits == other.residualBits &&
            elementCount == other.elementCount &&
            packed.contentEquals(other.packed)
    }

    override fun hashCode(): Int {
        var result = packed.contentHashCode()
        result = 31 * result + scale.hashCode()
        result = 31 * result + residualBits
        result = 31 * result + elementCount
        return result
    }
}
/**
 * Seeded random rotation used as the first TurboQuant stage.
 *
 * A random near-orthogonal rotation spreads quantization error uniformly
 * across dimensions before scalar quantization. The rotation is
 * deterministic given a seed, so it can be undone at decode time without
 * storing any rotation matrix. Implemented as "random sign flips + fast
 * Walsh-Hadamard transform", giving O(d log d) cost instead of an O(d^2)
 * matrix multiply.
 */
public object RandomRotation {

    /**
     * Rotate [vector] in place: seeded sign flips, then a normalized WHT.
     *
     * @param vector Input/output vector (modified in place)
     * @param seed Deterministic seed for reproducibility
     */
    public fun rotate(vector: FloatArray, seed: Int) {
        randomSignFlip(vector, seed)
        walshHadamard(vector)
    }

    /**
     * Undo [rotate] in place. Both stages are involutions (the WHT because
     * it is normalized, the sign flip trivially), so the inverse is the
     * same two stages applied in the opposite order.
     *
     * @param vector Input/output vector (modified in place)
     * @param seed Same seed used during [rotate]
     */
    public fun inverseRotate(vector: FloatArray, seed: Int) {
        walshHadamard(vector)
        randomSignFlip(vector, seed)
    }

    /**
     * Multiply by a seeded diagonal matrix with entries in {-1, +1}.
     */
    internal fun randomSignFlip(vector: FloatArray, seed: Int) {
        val random = Random(seed)
        for (idx in vector.indices) {
            val flip = random.nextBoolean()
            if (flip) {
                vector[idx] = -vector[idx]
            }
        }
    }

    /**
     * In-place normalized Walsh-Hadamard transform.
     *
     * Operates on the largest power-of-2 prefix of the vector; tail
     * elements beyond that prefix are left untouched (they still receive
     * the sign flip from [rotate], which keeps the round trip exact).
     * Normalizing by 1/sqrt(len) makes the transform orthogonal and
     * therefore self-inverse.
     */
    internal fun walshHadamard(vector: FloatArray) {
        val n = vector.size
        if (n <= 1) return

        // Largest power of 2 that fits in n.
        var len = 1
        while (len * 2 <= n) len *= 2

        // Standard iterative butterfly over the power-of-2 prefix.
        var half = 1
        while (half < len) {
            var blockStart = 0
            while (blockStart < len) {
                for (idx in blockStart until blockStart + half) {
                    val a = vector[idx]
                    val b = vector[idx + half]
                    vector[idx] = a + b
                    vector[idx + half] = a - b
                }
                blockStart += half * 2
            }
            half *= 2
        }

        val norm = 1.0f / sqrt(len.toFloat())
        for (idx in 0 until len) {
            vector[idx] *= norm
        }
    }

    /**
     * Deterministic, well-distributed seed for a (layer, head, position)
     * triple. Combines the inputs with a 31-multiplier hash, then applies
     * the MurmurHash3 finalizer to spread the bits.
     */
    public fun seedFor(layer: Int, head: Int, position: Int): Int {
        var mixed = layer
        mixed = mixed * 31 + head
        mixed = mixed * 31 + position
        mixed = mixed xor (mixed ushr 16)
        mixed *= 0x85ebca6b.toInt()
        mixed = mixed xor (mixed ushr 13)
        mixed *= 0xc2b2ae35.toInt()
        mixed = mixed xor (mixed ushr 16)
        return mixed
    }
}
/**
 * Scalar quantization and codebook lookup for TurboQuant.
 *
 * After random rotation spreads quantization error uniformly, scalar
 * quantization maps each element independently to an N-bit integer code —
 * simpler and faster than vector quantization while achieving good
 * quality thanks to the rotation preprocessing.
 *
 * Uniform symmetric scheme:
 * - per-group scale = max(abs(group)) / (2^(bits-1) - 1)
 * - code = round(value / scale), clamped to [-(2^(bits-1)-1), 2^(bits-1)-1]
 * - value ~= code * scale
 *
 * Groups of [GROUP_SIZE] elements share a single Float (FP32) scale factor.
 */
public object ScalarQuantizer {

    /** Number of elements per quantization group. */
    public const val GROUP_SIZE: Int = 32

    /**
     * Quantize a float vector to integer codes with per-group scales.
     *
     * All-zero groups get a zero scale and zero codes.
     *
     * @param input Float values (already rotated)
     * @param bits Bits per code (2, 3, 4, or 8)
     * @return [QuantizedVector] containing codes and scales
     * @throws IllegalArgumentException for unsupported [bits]
     */
    public fun quantize(input: FloatArray, bits: Int): QuantizedVector {
        require(bits in setOf(2, 3, 4, 8)) { "bits must be 2, 3, 4, or 8, got $bits" }

        val maxCode = (1 shl (bits - 1)) - 1 // e.g. 7 for 4-bit, 1 for 2-bit
        val groupCount = (input.size + GROUP_SIZE - 1) / GROUP_SIZE
        val scales = FloatArray(groupCount)
        val codes = ByteArray(input.size)

        for (group in 0 until groupCount) {
            val from = group * GROUP_SIZE
            val to = min(from + GROUP_SIZE, input.size)

            var peak = 0f
            for (idx in from until to) {
                peak = max(peak, abs(input[idx]))
            }

            // Guard against division by zero for all-zero groups.
            val scale = if (peak > 0f) peak / maxCode else 0f
            scales[group] = scale

            if (scale > 0f) {
                val invScale = 1f / scale
                for (idx in from until to) {
                    codes[idx] = round(input[idx] * invScale).toInt()
                        .coerceIn(-maxCode, maxCode)
                        .toByte()
                }
            }
            // else: codes remain 0
        }

        return QuantizedVector(codes, scales, bits)
    }

    /**
     * Reconstruct float values from quantized codes and their scales.
     *
     * @param quantized The quantized codes and scales
     * @return Reconstructed float values
     */
    public fun dequantize(quantized: QuantizedVector): FloatArray {
        val output = FloatArray(quantized.codes.size)
        dequantizeInto(quantized.codes, quantized.scales, output)
        return output
    }

    /**
     * Dequantize into an existing array starting at [offset].
     *
     * @param codes Quantized codes
     * @param scales Per-group scale factors
     * @param output Destination array
     * @param offset Starting offset in output
     */
    public fun dequantizeInto(
        codes: ByteArray,
        scales: FloatArray,
        output: FloatArray,
        offset: Int = 0
    ) {
        for (group in scales.indices) {
            val from = group * GROUP_SIZE
            val to = min(from + GROUP_SIZE, codes.size)
            val scale = scales[group]
            for (idx in from until to) {
                output[offset + idx] = codes[idx].toFloat() * scale
            }
        }
    }
}

/**
 * Result of scalar quantization: integer codes + per-group FP32 scales.
 */
public data class QuantizedVector(
    /** Signed integer codes, one per element. Values in [-maxCode, maxCode]. */
    val codes: ByteArray,
    /** Per-group scale factors (one per GROUP_SIZE elements). */
    val scales: FloatArray,
    /** Number of bits per code. */
    val bits: Int
) {
    val elementCount: Int get() = codes.size
    val numGroups: Int get() = scales.size

    // Manual equals/hashCode: arrays need content-based comparison.
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is QuantizedVector) return false
        return bits == other.bits &&
            codes.contentEquals(other.codes) &&
            scales.contentEquals(other.scales)
    }

    override fun hashCode(): Int {
        var result = codes.contentHashCode()
        result = 31 * result + scales.contentHashCode()
        result = 31 * result + bits
        return result
    }
}
/**
 * End-to-end TurboQuant encode/decode codec.
 *
 * Wires together the full pipeline:
 * 1. Random rotation (spread quantization error)
 * 2. Scalar quantization (map to N-bit codes)
 * 3. Optional QJL residual (preserve inner-product accuracy)
 * 4. Bit-packing (compact storage)
 *
 * Variants:
 * - **PolarOnly**: steps 1-2-4 (fast, backend-friendly)
 * - **PolarPlusQjl**: steps 1-2-3-4 (higher accuracy)
 */
public object TurboQuantCodec {

    /**
     * Encode a float vector using TurboQuant.
     *
     * @param input Raw float vector (e.g., a K or V projection for one head)
     * @param config Encoding configuration
     * @return Encoded block ready for storage
     */
    public fun encode(input: FloatArray, config: TurboQuantConfig): TurboQuantBlock {
        // 1. Random rotation (spreads quantization error across dimensions).
        val rotated = input.copyOf()
        RandomRotation.rotate(rotated, config.seed)

        // 2. Scalar quantization of the rotated vector.
        val quantized = ScalarQuantizer.quantize(rotated, config.bits)

        // 3. Bit-pack the codes.
        val packedCodes = BitPacker.pack(quantized.codes, config.bits)

        // 4. Optional QJL residual over (rotated - dequantized), seeded with
        //    seed + 1 so it does not reuse the rotation's PRNG stream.
        val residual = if (config.useQjl) {
            val dequantized = ScalarQuantizer.dequantize(quantized)
            val residualVec = FloatArray(input.size) { rotated[it] - dequantized[it] }
            QjlResidual.encode(residualVec, config.residualBits, config.seed + 1)
        } else null

        return TurboQuantBlock(
            packedCodes = packedCodes,
            scales = quantized.scales,
            seed = config.seed,
            bits = config.bits,
            elementCount = input.size,
            residual = residual
        )
    }

    /**
     * Decode a TurboQuant block back to float values.
     *
     * @param block The encoded block
     * @return Reconstructed float vector
     */
    public fun decode(block: TurboQuantBlock): FloatArray {
        // 1. Unpack codes.
        val codes = BitPacker.unpack(block.packedCodes, block.elementCount, block.bits)

        // 2. Dequantize.
        val output = FloatArray(block.elementCount)
        ScalarQuantizer.dequantizeInto(codes, block.scales, output)

        // 3. Add the QJL residual while still in rotated space.
        if (block.residual != null) {
            QjlResidual.decode(block.residual, output, block.seed + 1)
        }

        // 4. Undo the rotation.
        RandomRotation.inverseRotate(output, block.seed)

        return output
    }

    /**
     * Compute the byte size of an encoded block without encoding anything.
     *
     * Counts the packed codes, one FP32 scale per quantization group, the
     * 4-byte rotation seed and, when QJL is enabled, the packed residual
     * bits plus its single FP32 scale.
     */
    public fun encodedSize(elementCount: Int, config: TurboQuantConfig): Int {
        val codeBytes = BitPacker.packedSize(elementCount, config.bits)
        val groupCount = (elementCount + ScalarQuantizer.GROUP_SIZE - 1) / ScalarQuantizer.GROUP_SIZE
        val scaleBytes = groupCount * 4
        val seedBytes = 4
        val residualBytes = if (config.useQjl) {
            // BitPacker.packedSize rejected bits = 1, yet residualBits = 1 is
            // the QJL default (8 sign bits per byte as produced by
            // QjlResidual); size that case directly instead of crashing.
            val packedResidual = if (config.residualBits == 1) {
                (elementCount + 7) / 8
            } else {
                BitPacker.packedSize(elementCount, config.residualBits)
            }
            packedResidual + 4 // packed bits + one FP32 scale
        } else 0
        return codeBytes + scaleBytes + seedBytes + residualBytes
    }
}

/**
 * Configuration for TurboQuant encoding.
 */
public data class TurboQuantConfig(
    /** Bits per quantized code (2, 3, 4, or 8). */
    val bits: Int = 4,
    /** Whether to use the QJL residual stage. */
    val useQjl: Boolean = false,
    /** Bits for the QJL residual (1-4, only used if [useQjl] is true). */
    val residualBits: Int = 1,
    /** Deterministic seed for the random rotation. */
    val seed: Int = 0
) {
    init {
        require(bits in setOf(2, 3, 4, 8)) { "bits must be 2, 3, 4, or 8, got $bits" }
        if (useQjl) {
            require(residualBits in 1..4) { "residualBits must be 1-4, got $residualBits" }
        }
    }

    public companion object {
        /** Create a config for the PolarOnly variant (no residual stage). */
        public fun polarOnly(bits: Int = 4, seed: Int = 0): TurboQuantConfig =
            TurboQuantConfig(bits = bits, useQjl = false, seed = seed)

        /** Create a config for the Polar+QJL variant (residual stage enabled). */
        public fun polarPlusQjl(bits: Int = 4, residualBits: Int = 1, seed: Int = 0): TurboQuantConfig =
            TurboQuantConfig(bits = bits, useQjl = true, residualBits = residualBits, seed = seed)
    }
}

/**
 * A single TurboQuant-encoded block.
 *
 * Contains all data needed to reconstruct the original float vector.
 */
public data class TurboQuantBlock(
    /** Bit-packed quantization codes. */
    val packedCodes: ByteArray,
    /** Per-group scale factors. */
    val scales: FloatArray,
    /** Rotation seed for reproducibility. */
    val seed: Int,
    /** Bits per code. */
    val bits: Int,
    /** Number of logical float elements. */
    val elementCount: Int,
    /** Optional QJL residual (null for PolarOnly). */
    val residual: EncodedResidual? = null
) {
    /**
     * Total bytes used by this block: packed codes + FP32 scales + seed +
     * residual bits.
     *
     * NOTE(review): unlike [TurboQuantCodec.encodedSize], this does not add
     * 4 bytes for the residual's scale — confirm which convention the
     * serialization format should use.
     */
    val sizeInBytes: Int
        get() = packedCodes.size + scales.size * 4 + 4 + (residual?.packedSizeBytes ?: 0)

    /** True when no residual stage was used (PolarOnly variant). */
    val isPolarOnly: Boolean get() = residual == null

    // Manual equals/hashCode: arrays need content-based comparison.
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is TurboQuantBlock) return false
        return seed == other.seed &&
            bits == other.bits &&
            elementCount == other.elementCount &&
            packedCodes.contentEquals(other.packedCodes) &&
            scales.contentEquals(other.scales) &&
            residual == other.residual
    }

    override fun hashCode(): Int {
        var result = packedCodes.contentHashCode()
        result = 31 * result + scales.contentHashCode()
        result = 31 * result + seed
        result = 31 * result + bits
        result = 31 * result + elementCount
        result = 31 * result + (residual?.hashCode() ?: 0)
        return result
    }
}
/**
 * Named preset configurations for TurboQuant KV-cache compression.
 *
 * Presets reflect the practical observation that key precision is often
 * more quality-sensitive than value precision.
 *
 * Available presets:
 * - **safe-lowbit**: Q8_0 keys + TurboQuant-4 values (conservative)
 * - **balanced**: TurboQuant-4 keys + TurboQuant-4 values
 * - **experimental-max**: TurboQuant-3 keys + TurboQuant-3 values (aggressive)
 */
public object TurboQuantPresets {

    // All presets share the same KvCacheConfig shape; only the K/V
    // encodings differ, so construction is centralized here.
    private fun buildCacheConfig(
        numLayers: Int,
        numHeads: Int,
        headDim: Int,
        maxSeqLen: Int,
        keyEncoding: TensorEncoding,
        valueEncoding: TensorEncoding
    ): KvCacheConfig = KvCacheConfig(
        numLayers = numLayers,
        numHeads = numHeads,
        headDim = headDim,
        maxSeqLen = maxSeqLen,
        keyEncoding = keyEncoding,
        valueEncoding = valueEncoding,
        placement = Placement.CPU_HEAP.copy(residency = Residency.PERSISTENT)
    )

    // Symmetric presets use the same TurboQuant bit width for K and V.
    private fun symmetricPreset(
        presetName: String,
        quantBits: Int,
        numLayers: Int,
        numHeads: Int,
        headDim: Int,
        maxSeqLen: Int
    ): TurboQuantPreset = TurboQuantPreset(
        name = presetName,
        cacheConfig = buildCacheConfig(
            numLayers, numHeads, headDim, maxSeqLen,
            keyEncoding = TensorEncoding.TurboQuantPolar(bitsPerElement = quantBits),
            valueEncoding = TensorEncoding.TurboQuantPolar(bitsPerElement = quantBits)
        ),
        keyQuantConfig = TurboQuantConfig.polarOnly(bits = quantBits),
        valueQuantConfig = TurboQuantConfig.polarOnly(bits = quantBits)
    )

    /**
     * Safe low-bit preset: Q8_0 for keys, TurboQuant-4 for values.
     *
     * Keys stay at 8-bit for quality preservation; values are compressed
     * to 4-bit TurboQuant. Good for production use where key accuracy
     * matters more than value accuracy.
     */
    public fun safeLowbit(
        numLayers: Int,
        numHeads: Int,
        headDim: Int,
        maxSeqLen: Int
    ): TurboQuantPreset = TurboQuantPreset(
        name = "safe-lowbit",
        cacheConfig = buildCacheConfig(
            numLayers, numHeads, headDim, maxSeqLen,
            keyEncoding = TensorEncoding.Q8_0,
            valueEncoding = TensorEncoding.TurboQuantPolar(bitsPerElement = 4)
        ),
        keyQuantConfig = null, // Q8_0 uses standard quantization, not TurboQuant
        valueQuantConfig = TurboQuantConfig.polarOnly(bits = 4)
    )

    /**
     * Balanced preset: TurboQuant-4 for both keys and values.
     *
     * Symmetric 4-bit compression — a good balance between compression
     * ratio and quality.
     */
    public fun balanced(
        numLayers: Int,
        numHeads: Int,
        headDim: Int,
        maxSeqLen: Int
    ): TurboQuantPreset =
        symmetricPreset("balanced", 4, numLayers, numHeads, headDim, maxSeqLen)

    /**
     * Experimental maximum compression: TurboQuant-3 for both K and V.
     *
     * Aggressive 3-bit compression. Use with caution — may degrade quality
     * for some models. Best suited for long-context scenarios where memory
     * is the primary constraint.
     */
    public fun experimentalMax(
        numLayers: Int,
        numHeads: Int,
        headDim: Int,
        maxSeqLen: Int
    ): TurboQuantPreset =
        symmetricPreset("experimental-max", 3, numLayers, numHeads, headDim, maxSeqLen)

    /** All available preset names. */
    public val availablePresets: List<String> = listOf("safe-lowbit", "balanced", "experimental-max")

    /**
     * Look up a preset by name and apply model dimensions — the primary
     * entry point for consumers that want to enable TurboQuant with a
     * single call.
     *
     * @param preset One of "safe-lowbit", "balanced", "experimental-max"
     * @param numLayers Number of transformer layers
     * @param numHeads Number of KV heads per layer
     * @param headDim Dimension per head
     * @param maxSeqLen Maximum sequence length
     * @throws IllegalArgumentException if the preset name is unknown
     */
    public fun forModel(
        preset: String,
        numLayers: Int,
        numHeads: Int,
        headDim: Int,
        maxSeqLen: Int
    ): TurboQuantPreset = when (preset) {
        "safe-lowbit" -> safeLowbit(numLayers, numHeads, headDim, maxSeqLen)
        "balanced" -> balanced(numLayers, numHeads, headDim, maxSeqLen)
        "experimental-max" -> experimentalMax(numLayers, numHeads, headDim, maxSeqLen)
        else -> throw IllegalArgumentException(
            "Unknown TurboQuant preset: '$preset'. Available: $availablePresets"
        )
    }
}

/**
 * A named TurboQuant preset with all configuration needed to create a cache.
 */
public data class TurboQuantPreset(
    val name: String,
    val cacheConfig: KvCacheConfig,
    /** TurboQuant config for keys, or null if keys use non-TurboQuant encoding. */
    val keyQuantConfig: TurboQuantConfig?,
    /** TurboQuant config for values, or null if values use non-TurboQuant encoding. */
    val valueQuantConfig: TurboQuantConfig?
)
/**
 * Integration guide for TurboQuant KV-cache compression.
 *
 * TurboQuant compresses the KV cache at **runtime** — no model retraining
 * or weight re-quantization is required, so any model (LLaMA, Mistral,
 * Gemma, Qwen, …) benefits immediately. Model weights are untouched; only
 * the cache storage changes.
 *
 * During autoregressive inference the KV cache grows linearly with sequence
 * length and dominates memory. TurboQuant compresses K/V projections on
 * write and decompresses on read:
 * - **safe-lowbit**: Q8_0 keys + 4-bit values (~4-6x, conservative)
 * - **balanced**: 4-bit keys and values (~8x vs FP32)
 * - **experimental-max**: 3-bit keys and values (~10x, aggressive)
 *
 * Typical flow:
 * 1. Create a cache with `KvCacheStore.turboQuant(...)` — see [exampleLlamaCache].
 * 2. Wrap it in a [CompressedKvAttention] bridge inside the attention layer:
 *    `storeKeyValue()` on write, `loadKeysForAttention()` / `loadValuesForAttention()` on read.
 * 3. Optionally resolve a `@KvCache(preset = "...")` annotation via `KvCacheAnnotationResolver`.
 * 4. Inspect savings with `cache.memoryReport()` (compressionRatio, physical bytes, utilization).
 */
public object TurboQuantUsage {

    /**
     * Reference example: a "balanced" TurboQuant cache sized for a
     * LLaMA-7B-style model. Shows the one-line creation pattern.
     */
    public fun exampleLlamaCache(): KvCacheStore =
        KvCacheStore.turboQuant(
            "balanced",
            numLayers = 32,
            numHeads = 32, // or numKVHeads for GQA models
            headDim = 128,
            maxSeqLen = 4096
        )

    /**
     * Example of an asymmetric K/V budget: 8-bit keys, 4-bit values,
     * with GQA-style head dimensions.
     */
    public fun exampleAsymmetricCache(): KvCacheStore = KvCacheStore.turboQuant(
        numLayers = 32,
        numHeads = 8,   // GQA: 8 KV heads
        headDim = 128,
        maxSeqLen = 8192,
        keyBits = 8,    // keys keep high precision
        valueBits = 4   // values tolerate lower precision
    )

    /**
     * End-to-end sketch of a token-by-token generation loop backed by a
     * compressed cache. The K/V values here are synthetic; in real code
     * they come from the attention projection layers.
     */
    public fun exampleGenerationLoop() {
        val layers = 4
        val heads = 4
        val dim = 64
        val window = 128

        // Create the compressed cache and its attention bridge.
        val cache = KvCacheStore.turboQuant("balanced", layers, heads, dim, window)
        val bridge = CompressedKvAttention(cache)

        val projSize = heads * dim
        // Simulate generating 10 tokens.
        repeat(10) {
            for (layer in 0 until layers) {
                // Stand-ins for the K/V projection outputs.
                val key = FloatArray(projSize) { i -> i.toFloat() / projSize }
                val value = FloatArray(projSize) { i -> -i.toFloat() / projSize }

                // Compression happens transparently on append.
                bridge.storeKeyValue(layer, key, value)

                // Decompressed reads for the attention computation.
                val cachedKeys = bridge.loadKeysForAttention(layer)
                val cachedValues = bridge.loadValuesForAttention(layer)

                // ... feed cachedKeys / cachedValues to scaledDotProductAttention ...
            }
        }

        // Inspect the achieved compression.
        val report = cache.memoryReport()
        val savedBytes = report.totalLogicalBytes - report.totalPhysicalBytes
        // With the balanced preset this is roughly an 8x reduction.
    }
}
+ * For now, a simple global works for single-threaded inference. + */ +public object ActiveMemoryTracker { + public var current: MemoryTracker? = null + + /** Record a copy event on the active tracker, if any. */ + public fun recordCopy(source: String, bytes: Long) { + current?.recordCopy(source, bytes) + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/BufferAccessor.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/BufferAccessor.kt new file mode 100644 index 00000000..85338607 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/BufferAccessor.kt @@ -0,0 +1,112 @@ +package sk.ainet.lang.tensor.storage + +/** + * Provides byte-level read access to a [BufferHandle], regardless of its + * ownership mode. + * + * This is the bridge between the storage model (which describes *where* + * bytes live) and code that needs to actually read those bytes. For + * [BufferHandle.Owned] and [BufferHandle.Borrowed], access is direct. + * For [BufferHandle.FileBacked], a platform-specific resolver maps the + * file region into memory. + */ +public interface BufferAccessor : AutoCloseable { + + /** Total accessible bytes. */ + public val sizeInBytes: Long + + /** Read a single byte at [offset]. */ + public fun readByte(offset: Long): Byte + + /** Read [length] bytes starting at [offset]. */ + public fun readBytes(offset: Long, length: Int): ByteArray + + /** Read all bytes into a new array. Only practical for small buffers. */ + public fun readAllBytes(): ByteArray = readBytes(0, sizeInBytes.toInt()) +} + +/** + * Resolves a [BufferHandle] into a [BufferAccessor] that can read the + * underlying bytes. Platform-specific implementations handle file-backed + * and device-resident buffers; heap-backed handles are resolved generically. + */ +public interface BufferResolver { + + /** + * Open a [BufferAccessor] for the given handle. 
/**
 * Resolves a [BufferHandle] into a [BufferAccessor] that can read the
 * underlying bytes. Platform-specific implementations handle file-backed
 * and device-resident buffers; heap-backed handles are resolved generically.
 */
public interface BufferResolver {

    /**
     * Open a [BufferAccessor] for the given handle.
     * The caller is responsible for closing the returned accessor.
     */
    public fun resolve(handle: BufferHandle): BufferAccessor
}

/**
 * Default resolver that handles heap-backed handles directly and
 * delegates file-backed handles to a [fileBackedResolver].
 */
public class DefaultBufferResolver(
    private val fileBackedResolver: ((BufferHandle.FileBacked) -> BufferAccessor)? = null
) : BufferResolver {

    override fun resolve(handle: BufferHandle): BufferAccessor = when (handle) {
        is BufferHandle.Owned -> ByteArrayAccessor(handle.data, handle.offset, handle.sizeInBytes)
        is BufferHandle.Borrowed -> ByteArrayAccessor(handle.data, handle.offset, handle.sizeInBytes)
        // The slice wraps the freshly-resolved parent accessor.
        is BufferHandle.Aliased -> resolve(handle.parent).sliced(handle.byteOffset, handle.sizeInBytes)
        is BufferHandle.FileBacked -> {
            fileBackedResolver?.invoke(handle)
                ?: throw UnsupportedOperationException(
                    "No file-backed resolver configured. Cannot access ${handle.path}"
                )
        }
        is BufferHandle.DeviceResident -> throw UnsupportedOperationException(
            "Cannot resolve device-resident buffer ${handle.deviceId} on host"
        )
    }
}

/**
 * [BufferAccessor] over a plain [ByteArray].
 *
 * The accessor exposes a window of [sizeInBytes] bytes starting at [offset]
 * inside [data]. The window is validated at construction time, so reads can
 * never escape the backing array.
 */
public class ByteArrayAccessor(
    private val data: ByteArray,
    private val offset: Int = 0,
    override val sizeInBytes: Long = (data.size - offset).toLong()
) : BufferAccessor {

    init {
        // Fail at construction instead of on first read when the window
        // does not fit inside the backing array (previously unchecked).
        require(offset in 0..data.size) {
            "offset out of bounds: $offset (array size ${data.size})"
        }
        require(sizeInBytes >= 0 && offset + sizeInBytes <= data.size) {
            "window [offset=$offset, size=$sizeInBytes] exceeds array size ${data.size}"
        }
    }

    override fun readByte(offset: Long): Byte {
        require(offset in 0 until sizeInBytes) { "Offset out of bounds: $offset" }
        return data[this.offset + offset.toInt()]
    }

    override fun readBytes(offset: Long, length: Int): ByteArray {
        require(offset >= 0 && offset + length <= sizeInBytes) {
            "Range out of bounds: offset=$offset length=$length size=$sizeInBytes"
        }
        return data.copyOfRange(this.offset + offset.toInt(), this.offset + offset.toInt() + length)
    }

    override fun readAllBytes(): ByteArray {
        // Zero-copy fast path: when the window is the whole array, expose it
        // directly. Callers must treat the result as read-only.
        return if (offset == 0 && sizeInBytes.toInt() == data.size) data
        else data.copyOfRange(offset, offset + sizeInBytes.toInt())
    }

    override fun close() {} // no-op for heap arrays

    /**
     * Create a sub-accessor without copying.
     *
     * @throws IllegalArgumentException if the slice escapes this accessor's
     *   window — previously unchecked, which allowed a slice to silently
     *   read sibling bytes of the backing array.
     */
    public fun sliced(byteOffset: Long, size: Long): ByteArrayAccessor {
        require(byteOffset >= 0 && size >= 0 && byteOffset + size <= sizeInBytes) {
            "slice [offset=$byteOffset, size=$size] exceeds accessor size $sizeInBytes"
        }
        return ByteArrayAccessor(data, offset + byteOffset.toInt(), size)
    }
}
/** Helper to create a sliced accessor from any accessor. */
private fun BufferAccessor.sliced(byteOffset: Long, size: Long): BufferAccessor {
    if (this is ByteArrayAccessor) return this.sliced(byteOffset, size)
    // Fallback: wrap in a delegating accessor that owns the parent.
    return SlicedAccessor(this, byteOffset, size)
}

/**
 * Bounds-checked view over a region of a [parent] accessor.
 *
 * The slice takes ownership of [parent]: it is only created around a
 * freshly-resolved accessor (DefaultBufferResolver's Aliased branch), where
 * no other code holds the parent. Closing the slice therefore closes the
 * parent — previously `close()` was a no-op, which leaked file-backed
 * parent accessors.
 */
private class SlicedAccessor(
    private val parent: BufferAccessor,
    private val baseOffset: Long,
    override val sizeInBytes: Long
) : BufferAccessor {

    init {
        require(baseOffset >= 0 && sizeInBytes >= 0 && baseOffset + sizeInBytes <= parent.sizeInBytes) {
            "slice [offset=$baseOffset, size=$sizeInBytes] exceeds parent size ${parent.sizeInBytes}"
        }
    }

    override fun readByte(offset: Long): Byte {
        // Bounds-check against the slice window (previously delegated
        // unchecked, allowing reads to escape into sibling parent bytes).
        require(offset in 0 until sizeInBytes) { "Offset out of bounds: $offset" }
        return parent.readByte(baseOffset + offset)
    }

    override fun readBytes(offset: Long, length: Int): ByteArray {
        require(offset >= 0 && offset + length <= sizeInBytes) {
            "Range out of bounds: offset=$offset length=$length size=$sizeInBytes"
        }
        return parent.readBytes(baseOffset + offset, length)
    }

    override fun close() {
        parent.close() // slice is the sole owner of its freshly-resolved parent
    }
}

/**
 * Ownership / residency mode of a tensor's backing memory.
 *
 * Every [TensorStorage] holds a [BufferHandle] that describes *how* the
 * runtime acquired the bytes and therefore what operations are legal:
 *
 * | Mode            | Mutable? | Runtime owns memory? | Can outlive source? |
 * |-----------------|----------|----------------------|---------------------|
 * | [Owned]         | yes      | yes                  | yes                 |
 * | [Borrowed]      | no*      | no                   | no                  |
 * | [Aliased]       | no       | no (shared)          | tied to parent      |
 * | [FileBacked]    | no       | no (OS-managed)      | tied to mapping     |
 * | [DeviceResident]| varies   | backend-managed      | tied to device ctx  |
 *
 * *Borrowed buffers expose the original array but callers must not mutate it
 * unless they know the source permits mutation.
 */
public sealed interface BufferHandle {

    /** Total size in bytes of the accessible region. */
    public val sizeInBytes: Long

    /** Whether this handle permits writing into the buffer. */
    public val isMutable: Boolean

    /** Ownership classification for diagnostics. */
    public val ownership: Ownership

    /**
     * Runtime-allocated copy. The runtime owns the underlying memory and is
     * free to mutate or release it.
     */
    public class Owned(
        public val data: ByteArray,
        public val offset: Int = 0,
        override val sizeInBytes: Long = (data.size - offset).toLong()
    ) : BufferHandle {
        override val isMutable: Boolean get() = true
        override val ownership: Ownership get() = Ownership.OWNED

        init {
            // Validate the window, consistent with Aliased (previously unchecked).
            require(offset in 0..data.size) { "offset out of bounds: $offset (array size ${data.size})" }
            require(sizeInBytes >= 0 && offset + sizeInBytes <= data.size) {
                "window [offset=$offset, size=$sizeInBytes] exceeds array size ${data.size}"
            }
        }
    }

    /**
     * A reference to externally-owned memory (e.g. a caller-supplied array).
     * The runtime must not free or resize it. Mutation is possible only if
     * the source explicitly permits it.
     */
    public class Borrowed(
        public val data: ByteArray,
        public val offset: Int = 0,
        override val sizeInBytes: Long = (data.size - offset).toLong(),
        override val isMutable: Boolean = false
    ) : BufferHandle {
        override val ownership: Ownership get() = Ownership.BORROWED

        init {
            // Validate the window, consistent with Aliased (previously unchecked).
            require(offset in 0..data.size) { "offset out of bounds: $offset (array size ${data.size})" }
            require(sizeInBytes >= 0 && offset + sizeInBytes <= data.size) {
                "window [offset=$offset, size=$sizeInBytes] exceeds array size ${data.size}"
            }
        }
    }

    /**
     * A slice/view into another [BufferHandle]. Shares the parent's backing
     * memory. Mutations (if the parent is mutable) are visible to both.
     */
    public class Aliased(
        public val parent: BufferHandle,
        public val byteOffset: Long,
        override val sizeInBytes: Long
    ) : BufferHandle {
        override val isMutable: Boolean get() = parent.isMutable
        override val ownership: Ownership get() = Ownership.ALIASED

        init {
            require(byteOffset >= 0) { "byteOffset must be non-negative: $byteOffset" }
            require(byteOffset + sizeInBytes <= parent.sizeInBytes) {
                "Aliased region ($byteOffset + $sizeInBytes) exceeds parent (${parent.sizeInBytes})"
            }
        }
    }

    /**
     * Memory-mapped file region. Immutable from the runtime's perspective
     * (the OS manages paging and eviction).
     */
    public class FileBacked(
        public val path: String,
        public val fileOffset: Long,
        override val sizeInBytes: Long
    ) : BufferHandle {
        override val isMutable: Boolean get() = false
        override val ownership: Ownership get() = Ownership.FILE_BACKED
    }

    /**
     * Buffer managed by a compute backend (GPU, NPU, DSP, …).
     * Access semantics depend on the backend.
     */
    public class DeviceResident(
        public val deviceId: String,
        public val backendHandle: Any,
        override val sizeInBytes: Long,
        override val isMutable: Boolean
    ) : BufferHandle {
        override val ownership: Ownership get() = Ownership.DEVICE_RESIDENT
    }
}

/** Diagnostic classification matching the [BufferHandle] subtypes. */
public enum class Ownership {
    OWNED,
    BORROWED,
    ALIASED,
    FILE_BACKED,
    DEVICE_RESIDENT
}
/**
 * Factory and conversion utilities for creating [BufferHandle] instances
 * from common Kotlin types and for slicing existing handles.
 */
public object BufferHandleFactory {

    /** Create a [BufferHandle.Owned] by copying a ByteArray. */
    public fun owned(data: ByteArray): BufferHandle.Owned =
        BufferHandle.Owned(data.copyOf())

    /** Create a [BufferHandle.Owned] from a FloatArray (copies to little-endian bytes). */
    public fun owned(data: FloatArray): BufferHandle.Owned {
        val encoded = ByteArray(data.size * 4)
        data.forEachIndexed { index, value ->
            writeIntLe(encoded, index * 4, value.toRawBits())
        }
        return BufferHandle.Owned(encoded)
    }

    /** Create a [BufferHandle.Owned] from an IntArray (copies to little-endian bytes). */
    public fun owned(data: IntArray): BufferHandle.Owned {
        val encoded = ByteArray(data.size * 4)
        data.forEachIndexed { index, value ->
            writeIntLe(encoded, index * 4, value)
        }
        return BufferHandle.Owned(encoded)
    }

    /** Borrow a ByteArray without copying. Caller must ensure the array outlives the handle. */
    public fun borrow(data: ByteArray, mutable: Boolean = false): BufferHandle.Borrowed =
        BufferHandle.Borrowed(data, isMutable = mutable)

    /** Borrow a sub-range of [data] without copying. */
    public fun borrow(
        data: ByteArray,
        offset: Int,
        length: Int,
        mutable: Boolean = false
    ): BufferHandle.Borrowed =
        BufferHandle.Borrowed(data, offset = offset, sizeInBytes = length.toLong(), isMutable = mutable)

    /** Create a file-backed handle (metadata only — actual mapping is platform-specific). */
    public fun fileBacked(path: String, offset: Long, size: Long): BufferHandle.FileBacked =
        BufferHandle.FileBacked(path, offset, size)

    /** Create an aliased slice of an existing handle. */
    public fun slice(parent: BufferHandle, byteOffset: Long, sizeInBytes: Long): BufferHandle.Aliased =
        BufferHandle.Aliased(parent, byteOffset, sizeInBytes)

    /** Write [value] into [dest] starting at [at], little-endian byte order. */
    private fun writeIntLe(dest: ByteArray, at: Int, value: Int) {
        dest[at] = (value and 0xFF).toByte()
        dest[at + 1] = ((value shr 8) and 0xFF).toByte()
        dest[at + 2] = ((value shr 16) and 0xFF).toByte()
        dest[at + 3] = ((value shr 24) and 0xFF).toByte()
    }
}
/**
 * Bridge between [KvCacheStore] and the SDPA execution path.
 *
 * Rather than extending the core tensor-ops interface (which maps onto
 * backend-specific fused kernels), this component sits between the model
 * layer and SDPA:
 *
 * 1. **Write path**: [storeKeyValue] compresses K/V on token append.
 * 2. **Read path**: [loadKeysForAttention] / [loadValuesForAttention]
 *    dequantize the requested range.
 * 3. **Extension point**: [loadKeyStorageRaw] / [loadValueStorageRaw] expose
 *    the cache's native encoding for backends that fuse dequant+attention.
 *
 * Usage in a transformer layer:
 * ```kotlin
 * val bridge = CompressedKvAttention(kvCache)
 * bridge.storeKeyValue(layer, keyProjection, valueProjection)
 * val keys = bridge.loadKeysForAttention(layer)
 * val values = bridge.loadValuesForAttention(layer)
 * // pass keys, values to scaledDotProductAttention
 * ```
 */
public class CompressedKvAttention(
    private val cache: KvCacheStore,
    private val dequantStrategy: DequantStrategy = DequantStrategy.FULL_TILE
) {

    /**
     * Store K/V projections for a new token, compressing as configured.
     *
     * @param layer Layer index
     * @param key Key projection [numHeads, headDim]
     * @param value Value projection [numHeads, headDim]
     */
    public fun storeKeyValue(layer: Int, key: FloatArray, value: FloatArray) {
        cache.appendToken(layer, key, value)
    }

    /**
     * Load cached keys for attention, dequantizing as needed. The result is
     * shaped [numHeads, seqLen, headDim].
     *
     * Note: [DequantStrategy.RAW_STORAGE] currently falls back to the same
     * float path as [DequantStrategy.FULL_TILE]; backends that fuse
     * dequant+attention should use [loadKeyStorageRaw] instead.
     *
     * @param layer Layer index
     * @param startPos Start of the attention window (inclusive)
     * @param endPos End of the attention window (exclusive)
     */
    public fun loadKeysForAttention(
        layer: Int,
        startPos: Int = 0,
        endPos: Int = cache.currentSeqLen
    ): FloatArray = cache.readKeys(layer, startPos, endPos)

    /**
     * Load cached values for attention, dequantizing as needed.
     *
     * Note: [DequantStrategy.RAW_STORAGE] currently falls back to the same
     * float path as [DequantStrategy.FULL_TILE] — see [loadValueStorageRaw]
     * for the native-encoding path.
     *
     * @param layer Layer index
     * @param startPos Start of the attention window (inclusive)
     * @param endPos End of the attention window (exclusive)
     */
    public fun loadValuesForAttention(
        layer: Int,
        startPos: Int = 0,
        endPos: Int = cache.currentSeqLen
    ): FloatArray = cache.readValues(layer, startPos, endPos)

    /**
     * Load raw [TensorStorage] for keys, preserving the cache's native encoding.
     *
     * This is the zero-copy path for backends that can fuse decompression
     * with attention computation (e.g., Metal fused dequant+SDPA).
     */
    public fun loadKeyStorageRaw(
        layer: Int,
        startPos: Int = 0,
        endPos: Int = cache.currentSeqLen
    ): TensorStorage = cache.readKeyStorage(layer, startPos, endPos)

    /**
     * Load raw [TensorStorage] for values, preserving native encoding.
     */
    public fun loadValueStorageRaw(
        layer: Int,
        startPos: Int = 0,
        endPos: Int = cache.currentSeqLen
    ): TensorStorage = cache.readValueStorage(layer, startPos, endPos)

    /** Whether the cache uses compressed (non-Dense) encoding for keys. */
    public val isKeyCompressed: Boolean
        get() = cache.keyEncoding !is TensorEncoding.Dense

    /** Whether the cache uses compressed (non-Dense) encoding for values. */
    public val isValueCompressed: Boolean
        get() = cache.valueEncoding !is TensorEncoding.Dense

    /**
     * Strategy for dequantizing compressed K/V during attention.
     */
    public enum class DequantStrategy {
        /** Decompress the full requested tile to FP32 before attention. */
        FULL_TILE,
        /**
         * Return raw compressed storage — the backend is responsible for
         * fused dequant+attention. Falls back to [FULL_TILE] when no
         * backend fusion is available.
         */
        RAW_STORAGE
    }
}
/**
 * Default KV cache implementation using dense FP32 storage.
 *
 * This is the reference/baseline implementation that stores K/V as
 * uncompressed float arrays. Quantized implementations (Q8_0, TurboQuant)
 * will override [appendToken] and [readKeys]/[readValues] with
 * encode-on-write / decode-on-read paths.
 *
 * Internal layout per layer:
 * - keys: `FloatArray(numHeads * maxSeqLen * headDim)` — [numHeads, maxSeqLen, headDim]
 * - values: `FloatArray(numHeads * maxSeqLen * headDim)` — [numHeads, maxSeqLen, headDim]
 *
 * Append writes to position [currentSeqLen]; read returns a contiguous slice.
 */
public class DefaultKvCacheStore(
    private val config: KvCacheConfig
) : KvCacheStore {

    override val numLayers: Int get() = config.numLayers
    override val numHeads: Int get() = config.numHeads
    override val headDim: Int get() = config.headDim
    override val maxSeqLen: Int get() = config.maxSeqLen
    override val keyEncoding: TensorEncoding get() = config.keyEncoding
    override val valueEncoding: TensorEncoding get() = config.valueEncoding
    override val placement: Placement get() = config.placement

    private var _currentSeqLen: Int = 0
    override val currentSeqLen: Int get() = _currentSeqLen

    // Per-layer storage: keys[layer] and values[layer].
    // Each is [numHeads, maxSeqLen, headDim] laid out contiguously.
    // (Fix: the explicit type was the raw `Array` without a type argument,
    // which does not compile in Kotlin.)
    private val keys: Array<FloatArray> = Array(numLayers) {
        FloatArray(numHeads * maxSeqLen * headDim)
    }
    private val values: Array<FloatArray> = Array(numLayers) {
        FloatArray(numHeads * maxSeqLen * headDim)
    }

    override fun appendToken(layer: Int, key: FloatArray, value: FloatArray) {
        requireLayerIndex(layer)
        check(_currentSeqLen < maxSeqLen) {
            "KV cache is full: currentSeqLen=$_currentSeqLen, maxSeqLen=$maxSeqLen"
        }
        require(key.size == numHeads * headDim) {
            "Key size mismatch: expected ${numHeads * headDim}, got ${key.size}"
        }
        require(value.size == numHeads * headDim) {
            "Value size mismatch: expected ${numHeads * headDim}, got ${value.size}"
        }

        val pos = _currentSeqLen
        val layerKeys = keys[layer]
        val layerValues = values[layer]

        // Copy each head's slice into the [head, pos, :] position.
        for (h in 0 until numHeads) {
            val srcOffset = h * headDim
            val dstOffset = h * maxSeqLen * headDim + pos * headDim
            key.copyInto(layerKeys, dstOffset, srcOffset, srcOffset + headDim)
            value.copyInto(layerValues, dstOffset, srcOffset, srcOffset + headDim)
        }

        // Only increment seqLen when the last layer is written.
        // NOTE(review): this assumes callers append layers in order
        // 0..numLayers-1 for each token; out-of-order appends would write
        // several layers into the same position — confirm against callers.
        if (layer == numLayers - 1) {
            _currentSeqLen++
        }
    }

    override fun readKeys(layer: Int, startPos: Int, endPos: Int): FloatArray {
        // Validate the layer BEFORE indexing keys[layer] — previously an
        // invalid layer threw ArrayIndexOutOfBoundsException instead of the
        // intended IllegalArgumentException.
        requireLayerIndex(layer)
        return readRange(keys[layer], startPos, endPos)
    }

    override fun readValues(layer: Int, startPos: Int, endPos: Int): FloatArray {
        requireLayerIndex(layer)
        return readRange(values[layer], startPos, endPos)
    }

    override fun readKeyStorage(layer: Int, startPos: Int, endPos: Int): TensorStorage {
        return toTensorStorage(readKeys(layer, startPos, endPos), endPos - startPos, keyEncoding)
    }

    override fun readValueStorage(layer: Int, startPos: Int, endPos: Int): TensorStorage {
        return toTensorStorage(readValues(layer, startPos, endPos), endPos - startPos, valueEncoding)
    }

    override fun evict(fromPos: Int) {
        require(fromPos in 0..currentSeqLen) {
            "evict fromPos=$fromPos out of range [0, $currentSeqLen]"
        }
        _currentSeqLen = fromPos
        // Zero out evicted region for safety (prevents stale reads).
        for (layer in 0 until numLayers) {
            for (h in 0 until numHeads) {
                val offset = h * maxSeqLen * headDim + fromPos * headDim
                val count = (maxSeqLen - fromPos) * headDim
                keys[layer].fill(0f, offset, offset + count)
                values[layer].fill(0f, offset, offset + count)
            }
        }
    }

    override fun clear() {
        _currentSeqLen = 0
        for (layer in 0 until numLayers) {
            keys[layer].fill(0f)
            values[layer].fill(0f)
        }
    }

    override fun memoryReport(): KvCacheMemoryReport {
        val elementsPerLayer = numHeads.toLong() * maxSeqLen * headDim
        val logicalBytesPerLayer = elementsPerLayer * 4 // FP32
        // Dense storage: physical == logical (no compression).
        return KvCacheMemoryReport(
            numLayers = numLayers,
            numHeads = numHeads,
            headDim = headDim,
            maxSeqLen = maxSeqLen,
            currentSeqLen = _currentSeqLen,
            keyEncoding = keyEncoding,
            valueEncoding = valueEncoding,
            placement = placement,
            keyPhysicalBytes = numLayers * logicalBytesPerLayer,
            valuePhysicalBytes = numLayers * logicalBytesPerLayer,
            keyLogicalBytes = numLayers * logicalBytesPerLayer,
            valueLogicalBytes = numLayers * logicalBytesPerLayer
        )
    }

    // --- Internal helpers ---

    /** Copy [startPos, endPos) for every head into a contiguous result. */
    private fun readRange(
        layerData: FloatArray,
        startPos: Int,
        endPos: Int
    ): FloatArray {
        require(startPos in 0..endPos) { "Invalid range: startPos=$startPos, endPos=$endPos" }
        require(endPos <= _currentSeqLen) {
            "endPos=$endPos exceeds currentSeqLen=$_currentSeqLen"
        }

        val seqLen = endPos - startPos
        val result = FloatArray(numHeads * seqLen * headDim)
        for (h in 0 until numHeads) {
            val srcBase = h * maxSeqLen * headDim + startPos * headDim
            val dstBase = h * seqLen * headDim
            layerData.copyInto(result, dstBase, srcBase, srcBase + seqLen * headDim)
        }
        return result
    }

    /** Wrap dense FP32 floats in a [TensorStorage] (little-endian byte encoding). */
    private fun toTensorStorage(
        data: FloatArray,
        seqLen: Int,
        encoding: TensorEncoding
    ): TensorStorage {
        // Convert FloatArray to ByteArray for TensorStorage.
        val bytes = ByteArray(data.size * 4)
        for (i in data.indices) {
            val bits = data[i].toRawBits()
            bytes[i * 4] = (bits and 0xFF).toByte()
            bytes[i * 4 + 1] = ((bits shr 8) and 0xFF).toByte()
            bytes[i * 4 + 2] = ((bits shr 16) and 0xFF).toByte()
            bytes[i * 4 + 3] = ((bits shr 24) and 0xFF).toByte()
        }
        return TensorStorage(
            shape = Shape(numHeads, seqLen, headDim),
            logicalType = LogicalDType.FLOAT32,
            encoding = encoding,
            buffer = BufferHandle.Owned(bytes),
            placement = placement
        )
    }

    /** @throws IllegalArgumentException if [layer] is outside [0, numLayers). */
    private fun requireLayerIndex(layer: Int) {
        require(layer in 0 until numLayers) {
            "Layer index $layer out of range [0, $numLayers)"
        }
    }
}
/**
 * Resolves [KvCache] annotations to concrete [KvCacheStore] instances.
 *
 * skainet-transformers uses this to build KV caches declaratively: a layer
 * annotated with `@KvCache(preset = "balanced")` receives the matching
 * compressed (or dense) cache at construction time.
 *
 * Example:
 * ```kotlin
 * // In skainet-transformers attention layer:
 * @KvCache(preset = "balanced")
 * class SelfAttention(val numHeads: Int, val headDim: Int, ...) {
 *     val cache = KvCacheAnnotationResolver.resolve(
 *         annotation = this::class.annotations.filterIsInstance<KvCache>().first(),
 *         numLayers = modelConfig.numLayers,
 *         numHeads = numHeads,
 *         headDim = headDim,
 *         maxSeqLen = modelConfig.maxSeqLen
 *     )
 * }
 * ```
 */
public object KvCacheAnnotationResolver {

    /**
     * Resolve a [KvCache] annotation to a [KvCacheStore].
     *
     * NOTE(review): in this overload, preset "none" builds a *custom
     * TurboQuant* cache from the annotation's keyBits/valueBits/useQjl,
     * whereas the string-based overload below maps "none" to a dense cache
     * and the @KvCache KDoc says "none" means no compression — confirm
     * which semantics are intended for "none".
     *
     * @param annotation The @KvCache annotation values
     * @param numLayers Number of transformer layers
     * @param numHeads Number of KV heads per layer
     * @param headDim Dimension per head
     * @param maxSeqLen Maximum sequence length (the annotation overrides this when > 0)
     */
    public fun resolve(
        annotation: KvCache,
        numLayers: Int,
        numHeads: Int,
        headDim: Int,
        maxSeqLen: Int
    ): KvCacheStore {
        // Annotation-level maxSeqLen wins when explicitly set (> 0).
        val seqLen = annotation.maxSeqLen.takeIf { it > 0 } ?: maxSeqLen

        return when (annotation.preset) {
            // Custom config assembled from the annotation's own parameters.
            "none" -> KvCacheStore.turboQuant(
                numLayers = numLayers,
                numHeads = numHeads,
                headDim = headDim,
                maxSeqLen = seqLen,
                keyBits = annotation.keyBits,
                valueBits = annotation.valueBits,
                useQjl = annotation.useQjl
            )
            // Uncompressed FP32 baseline.
            "dense" -> KvCacheStore.dense(numLayers, numHeads, headDim, seqLen)
            // Any other string is treated as a named preset.
            else -> KvCacheStore.turboQuant(
                preset = annotation.preset,
                numLayers = numLayers,
                numHeads = numHeads,
                headDim = headDim,
                maxSeqLen = seqLen
            )
        }
    }

    /**
     * Resolve a preset name string to a [KvCacheStore].
     *
     * Convenience overload for callers that hold only the preset name,
     * not a full annotation instance.
     *
     * @param preset "dense", "safe-lowbit", "balanced", or "experimental-max"
     * @param numLayers Number of transformer layers
     * @param numHeads Number of KV heads per layer
     * @param headDim Dimension per head
     * @param maxSeqLen Maximum sequence length
     */
    public fun resolve(
        preset: String,
        numLayers: Int,
        numHeads: Int,
        headDim: Int,
        maxSeqLen: Int
    ): KvCacheStore = when (preset) {
        "dense", "none" -> KvCacheStore.dense(numLayers, numHeads, headDim, maxSeqLen)
        else -> KvCacheStore.turboQuant(preset, numLayers, numHeads, headDim, maxSeqLen)
    }
}
/**
 * Dedicated KV-cache storage abstraction for inference.
 *
 * Unlike generic [TensorStorage], a KV cache is **append-friendly** and
 * **role-aware**: keys and values may use different encodings and bit
 * budgets. The cache is addressed by (layer, head, position) and supports
 * compressed block storage for quantized formats.
 *
 * Backends and attention kernels talk to the cache through this interface
 * rather than managing raw tensors directly, which enables:
 * - compressed K/V writes on token append,
 * - dequantization of only the requested position range on read,
 * - asymmetric K/V policies (different key vs value encodings), and
 * - backend-specific fused dequant+attention paths via the raw-storage reads.
 */
public interface KvCacheStore {

    /** Number of transformer layers in this cache. */
    public val numLayers: Int

    /** Number of KV heads per layer. */
    public val numHeads: Int

    /** Dimension per head. */
    public val headDim: Int

    /** Maximum sequence length this cache can hold. */
    public val maxSeqLen: Int

    /** Current number of tokens stored in the cache. */
    public val currentSeqLen: Int

    /** Encoding used for key storage. */
    public val keyEncoding: TensorEncoding

    /** Encoding used for value storage. */
    public val valueEncoding: TensorEncoding

    /** Placement intent for the cache buffers. */
    public val placement: Placement

    /**
     * Append one token's K/V projections for a single layer.
     *
     * Called once per layer per generated token; the implementation encodes
     * the data according to [keyEncoding]/[valueEncoding].
     *
     * @param layer Layer index (0-based)
     * @param key Key projection [numHeads, headDim] as float
     * @param value Value projection [numHeads, headDim] as float
     * @throws IllegalStateException if the cache is full ([currentSeqLen] >= [maxSeqLen])
     */
    public fun appendToken(layer: Int, key: FloatArray, value: FloatArray)

    /**
     * Read cached keys for a layer, dequantized to float.
     *
     * Returns positions `[startPos, endPos)` as a contiguous float array
     * shaped [numHeads, (endPos - startPos), headDim].
     *
     * @param layer Layer index
     * @param startPos First token position (inclusive)
     * @param endPos Last token position (exclusive), defaults to [currentSeqLen]
     */
    public fun readKeys(layer: Int, startPos: Int = 0, endPos: Int = currentSeqLen): FloatArray

    /**
     * Read cached values for a layer, dequantized to float.
     *
     * Returns positions `[startPos, endPos)` as a contiguous float array
     * shaped [numHeads, (endPos - startPos), headDim].
     *
     * @param layer Layer index
     * @param startPos First token position (inclusive)
     * @param endPos Last token position (exclusive), defaults to [currentSeqLen]
     */
    public fun readValues(layer: Int, startPos: Int = 0, endPos: Int = currentSeqLen): FloatArray

    /**
     * Read raw (possibly compressed) key storage for a layer as [TensorStorage].
     *
     * Zero-copy path for backends that fuse dequantization with attention;
     * the returned storage carries the cache's native [keyEncoding].
     *
     * @param layer Layer index
     * @param startPos First token position (inclusive)
     * @param endPos Last token position (exclusive)
     */
    public fun readKeyStorage(layer: Int, startPos: Int = 0, endPos: Int = currentSeqLen): TensorStorage

    /**
     * Read raw (possibly compressed) value storage for a layer as [TensorStorage].
     *
     * @param layer Layer index
     * @param startPos First token position (inclusive)
     * @param endPos Last token position (exclusive)
     */
    public fun readValueStorage(layer: Int, startPos: Int = 0, endPos: Int = currentSeqLen): TensorStorage

    /**
     * Evict all cached tokens from position [fromPos] onward.
     *
     * Used for sequence truncation or speculative-decoding rollback; after
     * eviction, [currentSeqLen] becomes [fromPos].
     */
    public fun evict(fromPos: Int)

    /** Reset the cache, clearing all stored tokens. */
    public fun clear()

    /** Memory report for the entire cache. */
    public fun memoryReport(): KvCacheMemoryReport

    public companion object {
        /**
         * Create an uncompressed FP32 KV cache (baseline).
         *
         * Useful when compression is unwanted, or as a quality reference.
         */
        public fun dense(
            numLayers: Int,
            numHeads: Int,
            headDim: Int,
            maxSeqLen: Int
        ): KvCacheStore = DefaultKvCacheStore(
            KvCacheConfig.dense(numLayers, numHeads, headDim, maxSeqLen)
        )

        /**
         * Create a TurboQuant-compressed KV cache from a named preset
         * ("safe-lowbit", "balanced", "experimental-max").
         *
         * Example:
         * ```kotlin
         * val cache = KvCacheStore.turboQuant("balanced", numLayers=32, numHeads=32, headDim=128, maxSeqLen=4096)
         * ```
         *
         * @param preset Preset name (see [TurboQuantPresets.availablePresets])
         * @param numLayers Number of transformer layers
         * @param numHeads Number of KV heads per layer
         * @param headDim Dimension per head
         * @param maxSeqLen Maximum sequence length
         */
        public fun turboQuant(
            preset: String,
            numLayers: Int,
            numHeads: Int,
            headDim: Int,
            maxSeqLen: Int
        ): KvCacheStore =
            fromPreset(TurboQuantPresets.forModel(preset, numLayers, numHeads, headDim, maxSeqLen))

        /**
         * Create a TurboQuant-compressed KV cache with custom bit budgets.
         *
         * Example:
         * ```kotlin
         * // 8-bit keys, 4-bit values (safe-lowbit style)
         * val cache = KvCacheStore.turboQuant(
         *     numLayers=32, numHeads=32, headDim=128, maxSeqLen=4096,
         *     keyBits=8, valueBits=4
         * )
         * ```
         */
        public fun turboQuant(
            numLayers: Int,
            numHeads: Int,
            headDim: Int,
            maxSeqLen: Int,
            keyBits: Int = 4,
            valueBits: Int = 4,
            useQjl: Boolean = false
        ): KvCacheStore {
            // Keys and values get independent encodings sized by their bit budget.
            val config = KvCacheConfig(
                numLayers = numLayers,
                numHeads = numHeads,
                headDim = headDim,
                maxSeqLen = maxSeqLen,
                keyEncoding = TensorEncoding.TurboQuantPolar(bitsPerElement = keyBits),
                valueEncoding = TensorEncoding.TurboQuantPolar(bitsPerElement = valueBits)
            )
            fun quantConfig(bits: Int) =
                if (useQjl) TurboQuantConfig.polarPlusQjl(bits = bits)
                else TurboQuantConfig.polarOnly(bits = bits)
            return TurboQuantKvCacheStore(config, quantConfig(keyBits), quantConfig(valueBits))
        }

        /**
         * Create a KV cache from a [TurboQuantPreset]; missing K/V quant
         * configs fall back to 4-bit polar-only.
         */
        public fun fromPreset(preset: TurboQuantPreset): KvCacheStore =
            TurboQuantKvCacheStore(
                preset.cacheConfig,
                preset.keyQuantConfig ?: TurboQuantConfig.polarOnly(bits = 4),
                preset.valueQuantConfig ?: TurboQuantConfig.polarOnly(bits = 4)
            )
    }
}
/**
 * Configuration for asymmetric K/V encoding policies.
 *
 * Keys are often more quality-sensitive than values, so different bit
 * budgets may be appropriate — e.g. Q8_0 keys with 4-bit values
 * ("safe-lowbit") versus 4-bit keys and values ("balanced").
 */
public data class KvCacheConfig(
    val numLayers: Int,
    val numHeads: Int,
    val headDim: Int,
    val maxSeqLen: Int,
    val keyEncoding: TensorEncoding = TensorEncoding.Dense(4),
    val valueEncoding: TensorEncoding = TensorEncoding.Dense(4),
    val placement: Placement = Placement.CPU_HEAP.copy(residency = Residency.PERSISTENT)
) {
    init {
        // All geometry parameters must be strictly positive.
        require(numLayers > 0) { "numLayers must be positive: $numLayers" }
        require(numHeads > 0) { "numHeads must be positive: $numHeads" }
        require(headDim > 0) { "headDim must be positive: $headDim" }
        require(maxSeqLen > 0) { "maxSeqLen must be positive: $maxSeqLen" }
    }

    public companion object {
        /** Uncompressed FP32 cache (baseline). */
        public fun dense(numLayers: Int, numHeads: Int, headDim: Int, maxSeqLen: Int): KvCacheConfig =
            KvCacheConfig(numLayers, numHeads, headDim, maxSeqLen)

        /** Q8_0-compressed cache for both K and V. */
        public fun q8(numLayers: Int, numHeads: Int, headDim: Int, maxSeqLen: Int): KvCacheConfig =
            KvCacheConfig(
                numLayers = numLayers,
                numHeads = numHeads,
                headDim = headDim,
                maxSeqLen = maxSeqLen,
                keyEncoding = TensorEncoding.Q8_0,
                valueEncoding = TensorEncoding.Q8_0
            )
    }
}

/**
 * Memory report for a KV cache instance.
 *
 * Physical bytes reflect what is actually allocated; logical bytes are the
 * dense-FP32 equivalent, so [compressionRatio] > 1 means the cache is
 * smaller than its dense baseline.
 */
public data class KvCacheMemoryReport(
    val numLayers: Int,
    val numHeads: Int,
    val headDim: Int,
    val maxSeqLen: Int,
    val currentSeqLen: Int,
    val keyEncoding: TensorEncoding,
    val valueEncoding: TensorEncoding,
    val placement: Placement,
    val keyPhysicalBytes: Long,
    val valuePhysicalBytes: Long,
    val keyLogicalBytes: Long,
    val valueLogicalBytes: Long
) {
    /** Combined physical footprint of keys and values. */
    val totalPhysicalBytes: Long get() = keyPhysicalBytes + valuePhysicalBytes

    /** Combined logical (uncompressed-equivalent) size of keys and values. */
    val totalLogicalBytes: Long get() = keyLogicalBytes + valueLogicalBytes

    /** logical / physical; 1.0 when nothing is allocated. */
    val compressionRatio: Double
        get() = totalPhysicalBytes.takeIf { it > 0 }
            ?.let { totalLogicalBytes.toDouble() / it }
            ?: 1.0

    /** Fraction of [maxSeqLen] currently occupied; 0.0 when maxSeqLen is 0. */
    val utilizationRatio: Double
        get() = if (maxSeqLen > 0) currentSeqLen.toDouble() / maxSeqLen else 0.0
}
/**
 * Logical numeric type — what the tensor values mean semantically.
 *
 * Intentionally separate from [TensorEncoding], which describes physical
 * layout: a tensor with logical type [FLOAT32] may be stored as
 * [TensorEncoding.Dense], [TensorEncoding.Q4_K], etc.
 */
public enum class LogicalDType(
    public val sizeInBits: Int,
    public val isFloatingPoint: Boolean,
    public val isSigned: Boolean
) {
    // Ternary values {-1, 0, +1} are budgeted at 2 bits per element.
    TERNARY(2, isFloatingPoint = false, isSigned = true),
    INT4(4, isFloatingPoint = false, isSigned = true),
    INT8(8, isFloatingPoint = false, isSigned = true),
    INT16(16, isFloatingPoint = false, isSigned = true),
    INT32(32, isFloatingPoint = false, isSigned = true),
    INT64(64, isFloatingPoint = false, isSigned = true),
    UINT8(8, isFloatingPoint = false, isSigned = false),
    UINT16(16, isFloatingPoint = false, isSigned = false),
    UINT32(32, isFloatingPoint = false, isSigned = false),
    UINT64(64, isFloatingPoint = false, isSigned = false),
    FLOAT16(16, isFloatingPoint = true, isSigned = true),
    BFLOAT16(16, isFloatingPoint = true, isSigned = true),
    FLOAT32(32, isFloatingPoint = true, isSigned = true),
    FLOAT64(64, isFloatingPoint = true, isSigned = true);

    /** Byte size, rounded up for sub-byte types (TERNARY, INT4 → 1). */
    public val sizeInBytes: Int get() = (sizeInBits + 7) / 8

    public companion object {
        /** Map a runtime [DType] instance onto its logical category. */
        public fun fromDType(dtype: DType): LogicalDType = when (dtype) {
            is Ternary -> TERNARY
            is Int4 -> INT4
            is Int8 -> INT8
            is Int16 -> INT16
            is Int32 -> INT32
            is Int64 -> INT64
            is UInt8 -> UINT8
            is UInt16 -> UINT16
            is UInt32 -> UINT32
            is UInt64 -> UINT64
            is FP16 -> FLOAT16
            is BF16 -> BFLOAT16
            is FP32 -> FLOAT32
            is FP64 -> FLOAT64
        }
    }
}
/**
 * Resolves [Placement] intent into concrete buffer allocation decisions.
 *
 * The planner inspects available backends and decides:
 * - where a tensor should actually live (device + memory domain),
 * - whether a fallback is needed (e.g. GPU not available → CPU), and
 * - whether immutable weights should be file-backed vs heap-copied.
 *
 * Currently only the CPU backend is wired in, so resolution always lands
 * on CPU; GPU/NPU resolution arrives with those backends.
 */
public class MemoryPlanner(
    private val availableDevices: Set<DeviceKind> = setOf(DeviceKind.CPU)
) {

    /**
     * Resolve a placement intent to one that can actually be satisfied.
     *
     * @param requested The user/loader-requested placement
     * @return A [ResolvedPlacement] with the actual target and whether fallback was used
     * @throws PlacementUnavailableException when the target is unavailable
     *         and the requirement is [Requirement.REQUIRED]
     */
    public fun resolve(requested: Placement): ResolvedPlacement {
        val target = when (requested.device) {
            DeviceKind.AUTO -> bestAvailableDevice()
            else -> requested.device
        }

        if (target in availableDevices) {
            return ResolvedPlacement(
                actual = requested.copy(device = target),
                usedFallback = false
            )
        }
        if (requested.requirement == Requirement.REQUIRED) {
            throw PlacementUnavailableException(
                "Required device $target is not available. Available: $availableDevices"
            )
        }

        // Best-effort: use the placement's declared fallback device, or CPU
        // when that too is unavailable.
        val fallbackDevice = requested.fallback.takeIf { it in availableDevices } ?: DeviceKind.CPU
        return ResolvedPlacement(
            actual = Placement(
                device = fallbackDevice,
                domain = fallbackDomain(requested.domain, fallbackDevice),
                residency = requested.residency,
                requirement = requested.requirement,
                fallback = requested.fallback
            ),
            usedFallback = true
        )
    }

    /**
     * Suggest the best placement for a weight tensor:
     * file-backed (mmap) when the source is a file, persistent heap otherwise.
     */
    public fun suggestWeightPlacement(isFileBacked: Boolean): Placement = when {
        isFileBacked -> Placement.MMAP_WEIGHTS
        else -> Placement.CPU_HEAP.copy(residency = Residency.PERSISTENT)
    }

    /** Suggest placement for a mutable activation/intermediate tensor. */
    public fun suggestActivationPlacement(): Placement = Placement.CPU_HEAP

    // AUTO resolution order: GPU, then NPU, then CPU as the universal default.
    private fun bestAvailableDevice(): DeviceKind =
        listOf(DeviceKind.GPU, DeviceKind.NPU).firstOrNull { it in availableDevices }
            ?: DeviceKind.CPU

    // When falling back to CPU, device-only domains become host heap;
    // everything else passes through untouched.
    private fun fallbackDomain(requested: MemoryDomain, device: DeviceKind): MemoryDomain =
        if (device == DeviceKind.CPU &&
            (requested == MemoryDomain.DEVICE_LOCAL || requested == MemoryDomain.UNIFIED)
        ) {
            MemoryDomain.HOST_HEAP
        } else {
            requested
        }
}

/** Outcome of placement resolution: the actual placement and whether a fallback occurred. */
public data class ResolvedPlacement(
    val actual: Placement,
    val usedFallback: Boolean
)

/** Thrown when a [Requirement.REQUIRED] placement cannot be satisfied. */
public class PlacementUnavailableException(message: String) : RuntimeException(message)
/**
 * Tracks memory allocation events and reports aggregate statistics across
 * all live [TensorStorage] instances.
 *
 * Call [record] when storage is created and [report] for a snapshot.
 * Primarily for debugging and regression testing (e.g. "assert no
 * unexpected copies in this inference pass").
 *
 * Not thread-safe; callers synchronize externally if needed.
 */
public class MemoryTracker {

    private val entries = mutableListOf<TrackedEntry>()
    private var copyCount: Long = 0
    private var copyBytes: Long = 0

    /** Record a tensor storage allocation under [name]. */
    public fun record(name: String, storage: TensorStorage) {
        entries += TrackedEntry(name, storage.memoryReport())
    }

    /**
     * Record an explicit copy event (copy-tracing).
     * [sourceName] is currently unused but kept for API stability.
     */
    public fun recordCopy(sourceName: String, bytes: Long) {
        copyCount += 1
        copyBytes += bytes
    }

    /** Reset all tracked entries and copy counters. */
    public fun clear() {
        entries.clear()
        copyCount = 0
        copyBytes = 0
    }

    /** Generate an aggregate memory report over everything recorded so far. */
    public fun report(): AggregateMemoryReport {
        val reports = entries.map { it.report }
        val fileBacked = reports.filter { it.isFileBacked }
        return AggregateMemoryReport(
            tensorCount = entries.size,
            totalLogicalBytes = reports.sumOf { it.logicalBytes },
            totalPhysicalBytes = reports.sumOf { it.physicalBytes },
            fileBackedBytes = fileBacked.sumOf { it.physicalBytes },
            ownedCount = reports.count { it.ownership == Ownership.OWNED },
            borrowedCount = reports.count { it.ownership == Ownership.BORROWED },
            aliasedCount = reports.count { it.isAlias },
            fileBackedCount = fileBacked.size,
            copyCount = copyCount,
            copyBytes = copyBytes,
            entries = entries.toList()
        )
    }
}

/** A named per-tensor report captured by [MemoryTracker.record]. */
public data class TrackedEntry(
    val name: String,
    val report: StorageMemoryReport
)

/**
 * Aggregate snapshot produced by [MemoryTracker.report].
 *
 * Logical bytes are dense-equivalent sizes; physical bytes are actual
 * allocations, so [overallCompressionRatio] > 1 indicates net savings.
 */
public data class AggregateMemoryReport(
    val tensorCount: Int,
    val totalLogicalBytes: Long,
    val totalPhysicalBytes: Long,
    val fileBackedBytes: Long,
    val ownedCount: Int,
    val borrowedCount: Int,
    val aliasedCount: Int,
    val fileBackedCount: Int,
    val copyCount: Long,
    val copyBytes: Long,
    val entries: List<TrackedEntry>
) {
    /** logical / physical over all tensors; 1.0 when nothing is allocated. */
    val overallCompressionRatio: Double
        get() = if (totalPhysicalBytes > 0) totalLogicalBytes.toDouble() / totalPhysicalBytes else 1.0

    override fun toString(): String = buildString {
        appendLine("=== Memory Report ===")
        appendLine("Tensors: $tensorCount")
        appendLine("Logical: $totalLogicalBytes bytes")
        appendLine("Physical: $totalPhysicalBytes bytes")
        appendLine("File-backed: $fileBackedCount ($fileBackedBytes bytes)")
        appendLine("Owned: $ownedCount, Borrowed: $borrowedCount, Aliased: $aliasedCount")
        appendLine("Copies: $copyCount ($copyBytes bytes)")
        if (entries.isNotEmpty()) {
            appendLine("--- Per-tensor ---")
            for (e in entries) {
                appendLine("  ${e.name}: ${e.report}")
            }
        }
    }
}
/**
 * Shared contract for all packed/quantized block tensor storage formats.
 *
 * Instead of each quantization format (Q4_K, Q8_0, Ternary, …) inventing
 * its own loader, planner, and backend handling path, all packed formats
 * implement this interface. Backends and planners dispatch on [encoding]
 * without knowing every possible quantization scheme; format-specific
 * accessors live on each format's own sub-interface.
 */
public interface PackedBlockStorage {

    /** The logical shape of the tensor (element count, not block count). */
    public val shape: Shape

    /** The physical encoding describing the block layout. */
    public val encoding: TensorEncoding

    /** Number of blocks in this storage. */
    public val blockCount: Int

    /** Number of logical elements per block. */
    public val blockSize: Int

    /** Raw packed byte data containing all blocks. */
    public val packedData: ByteArray

    /** Physical byte size of the packed data. */
    public val physicalBytes: Long get() = packedData.size.toLong()

    /** Logical element count. */
    public val elementCount: Long get() = shape.volume.toLong()

    /**
     * Dequantize a single block to float values.
     *
     * @param blockIdx The block index (0-based)
     * @param output Destination array (must have at least [blockSize] elements from [outputOffset])
     * @param outputOffset Starting index in [output]
     */
    public fun dequantizeBlock(blockIdx: Int, output: FloatArray, outputOffset: Int = 0)

    /**
     * Dequantize the entire tensor to a FloatArray.
     *
     * Default implementation calls [dequantizeBlock] per block. The final
     * block may be partial when `shape.volume % blockSize != 0`; since
     * [dequantizeBlock] requires [blockSize] elements of headroom, that
     * tail is dequantized into a scratch buffer and only the logically
     * valid prefix is copied out. (Writing the partial block directly into
     * the result would overflow it.)
     */
    public fun toFloatArray(): FloatArray {
        val total = shape.volume
        val result = FloatArray(total)
        var offset = 0
        for (block in 0 until blockCount) {
            val remaining = total - offset
            if (remaining >= blockSize) {
                dequantizeBlock(block, result, offset)
                offset += blockSize
            } else {
                // Partial tail: bounce through scratch to honor
                // dequantizeBlock's headroom contract.
                val scratch = FloatArray(blockSize)
                dequantizeBlock(block, scratch, 0)
                scratch.copyInto(result, offset, 0, remaining)
                offset = total
            }
        }
        return result
    }

    /**
     * Convert this packed storage to a [TensorStorage] descriptor.
     *
     * The descriptor borrows [packedData] immutably; no copy is made.
     */
    public fun toTensorStorage(
        logicalType: LogicalDType = LogicalDType.FLOAT32,
        placement: Placement = Placement.CPU_HEAP
    ): TensorStorage = TensorStorage(
        shape = shape,
        logicalType = logicalType,
        encoding = encoding,
        buffer = BufferHandle.Borrowed(packedData, isMutable = false),
        placement = placement
    )
}
/**
 * High-level placement descriptor: where a tensor lives and how the runtime
 * should manage it.
 *
 * Placement is *intent* — it tells the planner what to aim for but does not
 * encode backend scratch-memory details. The planner resolves placement to
 * a concrete [BufferHandle] and falls back if the preferred target is
 * unavailable.
 */
public data class Placement(
    val device: DeviceKind = DeviceKind.CPU,
    val domain: MemoryDomain = MemoryDomain.HOST_HEAP,
    val residency: Residency = Residency.PERSISTENT,
    val requirement: Requirement = Requirement.PREFERRED,
    val fallback: DeviceKind = DeviceKind.CPU
) {
    public companion object {
        /** Default CPU heap placement for mutable runtime buffers. */
        public val CPU_HEAP: Placement = Placement(
            device = DeviceKind.CPU,
            domain = MemoryDomain.HOST_HEAP,
            residency = Residency.TRANSIENT,
            requirement = Requirement.PREFERRED
        )

        /** File-backed placement for immutable model weights. */
        public val MMAP_WEIGHTS: Placement = Placement(
            device = DeviceKind.CPU,
            domain = MemoryDomain.MMAP_FILE,
            residency = Residency.PERSISTENT,
            requirement = Requirement.PREFERRED
        )

        /** GPU-preferred placement with CPU fallback. */
        public val GPU_PREFERRED: Placement = Placement(
            device = DeviceKind.GPU,
            domain = MemoryDomain.DEVICE_LOCAL,
            residency = Residency.PERSISTENT,
            requirement = Requirement.PREFERRED,
            fallback = DeviceKind.CPU
        )
    }
}

/** Compute device a tensor may be placed on; AUTO lets the planner choose. */
public enum class DeviceKind {
    AUTO,
    CPU,
    GPU,
    NPU
}

/** Memory domain a buffer is allocated in. */
public enum class MemoryDomain {
    /** Standard JVM / native heap allocation. */
    HOST_HEAP,

    /** Pinned (non-pageable) host memory for fast DMA transfers. */
    HOST_PINNED,

    /** Memory-mapped file (immutable, OS-paged). */
    MMAP_FILE,

    /** Unified / shared memory visible to both host and device. */
    UNIFIED,

    /** Device-local memory (fastest for compute, not directly host-accessible). */
    DEVICE_LOCAL
}

/** Expected lifetime of a buffer. */
public enum class Residency {
    /** Short-lived: activations, temporaries, intermediate results. */
    TRANSIENT,

    /** Long-lived: model weights, embeddings, caches. */
    PERSISTENT
}

/** How strictly the planner must honor a placement. */
public enum class Requirement {
    /** Best-effort: fall back to [Placement.fallback] if unavailable. */
    PREFERRED,

    /** Hard requirement: fail if the target is unavailable. */
    REQUIRED
}
/**
 * Declares placement intent for a tensor parameter or property.
 *
 * The [MemoryPlanner] reads these annotations (via reflection or codegen)
 * to decide where tensors should be allocated. This expresses *intent*,
 * not a hard guarantee — the planner may fall back if the target is
 * unavailable and [requirement] is [Requirement.PREFERRED].
 *
 * Example:
 * ```kotlin
 * @Place(device = DeviceKind.GPU, memory = MemoryDomain.DEVICE_LOCAL)
 * val projectionWeight: Tensor
 * ```
 */
@Target(AnnotationTarget.PROPERTY, AnnotationTarget.VALUE_PARAMETER, AnnotationTarget.FIELD)
@Retention(AnnotationRetention.RUNTIME)
public annotation class Place(
    val device: DeviceKind = DeviceKind.AUTO,
    val memory: MemoryDomain = MemoryDomain.HOST_HEAP,
    val requirement: Requirement = Requirement.PREFERRED
)

/**
 * Marks a tensor as an immutable weight that should be file-backed
 * (memory-mapped) when possible.
 *
 * Equivalent to `@Place(device = CPU, memory = MMAP_FILE)` with
 * [Residency.PERSISTENT]: the planner treats these tensors as read-only
 * and long-lived, preferring OS-paged file access over heap allocation.
 *
 * Example:
 * ```kotlin
 * @Weights
 * val embeddings: Tensor
 *
 * @Weights(memory = MemoryDomain.HOST_HEAP) // force heap for small weights
 * val biasVector: Tensor
 * ```
 */
@Target(AnnotationTarget.PROPERTY, AnnotationTarget.VALUE_PARAMETER, AnnotationTarget.FIELD)
@Retention(AnnotationRetention.RUNTIME)
public annotation class Weights(
    val memory: MemoryDomain = MemoryDomain.MMAP_FILE
)

/**
 * Configures TurboQuant KV-cache compression for an attention layer.
 *
 * Applied to attention layer properties; the runtime uses these values to
 * configure the [KvCacheStore] (and the compressed-attention bridge) for
 * each layer.
 *
 * Example:
 * ```kotlin
 * @KvCache(preset = "balanced")
 * val selfAttention: MultiHeadAttention
 *
 * @KvCache(keyBits = 8, valueBits = 4)
 * val crossAttention: MultiHeadAttention
 *
 * @KvCache(preset = "safe-lowbit", maxSeqLen = 4096)
 * val longContextAttention: MultiHeadAttention
 * ```
 */
@Target(AnnotationTarget.PROPERTY, AnnotationTarget.VALUE_PARAMETER, AnnotationTarget.FIELD)
@Retention(AnnotationRetention.RUNTIME)
public annotation class KvCache(
    /**
     * Named preset: "safe-lowbit", "balanced", "experimental-max", or "none".
     * When set to a named preset, [keyBits] and [valueBits] are ignored.
     * Default "none" means no TurboQuant compression (dense FP32 cache).
     * NOTE(review): KvCacheAnnotationResolver currently maps "none" to a
     * custom TurboQuant config instead — confirm the intended semantics.
     */
    val preset: String = "none",

    /**
     * Bits per element for key compression (2, 3, 4, or 8).
     * Only used when [preset] is "none" (custom config).
     */
    val keyBits: Int = 4,

    /**
     * Bits per element for value compression (2, 3, 4, or 8).
     * Only used when [preset] is "none" (custom config).
     */
    val valueBits: Int = 4,

    /**
     * Whether to use QJL residual for improved inner-product accuracy.
     * Only used when [preset] is "none" (custom config).
     */
    val useQjl: Boolean = false,

    /**
     * Maximum sequence length for the KV cache.
     * 0 means use the model's default.
     */
    val maxSeqLen: Int = 0,

    /** Preferred device for KV cache storage. */
    val device: DeviceKind = DeviceKind.AUTO
)

/**
 * Disables TurboQuant compression for a specific layer.
 *
 * When applied alongside a model-level [KvCache] annotation, this overrides
 * the compression setting for individual layers that are sensitive to
 * quantization (e.g. early layers or cross-attention).
 *
 * Example:
 * ```kotlin
 * @KvCacheBypass
 * val firstLayerAttention: MultiHeadAttention // stays FP32
 * ```
 */
@Target(AnnotationTarget.PROPERTY, AnnotationTarget.VALUE_PARAMETER, AnnotationTarget.FIELD)
@Retention(AnnotationRetention.RUNTIME)
public annotation class KvCacheBypass
*/ + val compressionRatio: Double + get() = if (physicalBytes > 0) logicalBytes.toDouble() / physicalBytes else 1.0 + + override fun toString(): String = buildString { + append("StorageMemoryReport(") + append("shape=$shape, ") + append("logical=$logicalType, ") + append("encoding=${encoding.name}, ") + append("ownership=$ownership, ") + append("placement=${placement.device}/${placement.domain}, ") + append("logicalBytes=$logicalBytes, ") + append("physicalBytes=$physicalBytes, ") + append("ratio=${((compressionRatio * 100).toLong() / 100.0)}, ") + append("fileBacked=$isFileBacked, ") + append("alias=$isAlias, ") + append("mutable=$isMutable") + append(")") + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/StorageSpec.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/StorageSpec.kt new file mode 100644 index 00000000..8f495131 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/StorageSpec.kt @@ -0,0 +1,62 @@ +package sk.ainet.lang.tensor.storage + +import sk.ainet.lang.types.DType + +/** + * A storage specification that captures both logical type AND physical + * encoding + placement intent. This enables factory routing that goes + * beyond dtype-only decisions. + * + * [StorageSpec] replaces the pattern of routing only by [DType] (via + * [sk.ainet.lang.tensor.data.TensorFactoryRegistry]). Existing dtype-based + * lookups remain as a convenience — they build a default [StorageSpec] + * with [TensorEncoding.Dense] and [Ownership.OWNED]. + */ +public data class StorageSpec( + val logicalType: LogicalDType, + val encoding: TensorEncoding = TensorEncoding.Dense(logicalType.sizeInBytes), + val ownership: Ownership = Ownership.OWNED, + val placement: Placement = Placement.CPU_HEAP +) { + public companion object { + /** Build a default spec from a legacy DType (dense, owned, CPU heap). 
*/ + public fun fromDType(dtype: DType): StorageSpec = StorageSpec( + logicalType = LogicalDType.fromDType(dtype), + encoding = TensorEncoding.Dense(LogicalDType.fromDType(dtype).sizeInBytes), + ownership = Ownership.OWNED, + placement = Placement.CPU_HEAP + ) + + /** Spec for borrowed dense data. */ + public fun borrowed(dtype: DType): StorageSpec = StorageSpec( + logicalType = LogicalDType.fromDType(dtype), + encoding = TensorEncoding.Dense(LogicalDType.fromDType(dtype).sizeInBytes), + ownership = Ownership.BORROWED, + placement = Placement.CPU_HEAP + ) + + /** Spec for Q4_K packed data. */ + public fun q4k(placement: Placement = Placement.CPU_HEAP): StorageSpec = StorageSpec( + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Q4_K, + ownership = Ownership.BORROWED, + placement = placement + ) + + /** Spec for Q8_0 packed data. */ + public fun q80(placement: Placement = Placement.CPU_HEAP): StorageSpec = StorageSpec( + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Q8_0, + ownership = Ownership.BORROWED, + placement = placement + ) + + /** Spec for file-backed weights. */ + public fun mmapWeights(dtype: DType): StorageSpec = StorageSpec( + logicalType = LogicalDType.fromDType(dtype), + encoding = TensorEncoding.Dense(LogicalDType.fromDType(dtype).sizeInBytes), + ownership = Ownership.FILE_BACKED, + placement = Placement.MMAP_WEIGHTS + ) + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/TensorEncoding.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/TensorEncoding.kt new file mode 100644 index 00000000..ef1b2894 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/TensorEncoding.kt @@ -0,0 +1,150 @@ +package sk.ainet.lang.tensor.storage + +/** + * Physical storage encoding — how tensor data is laid out in memory. 
+ * + * A [TensorEncoding] describes the byte-level format of a buffer, independent + * of the logical numeric type ([LogicalDType]). For example, a FLOAT32 tensor + * may be stored as [Dense] (4 bytes per element) or as [Q4_K] (packed 4-bit + * blocks with per-block scales). + * + * Encodings are sealed so that pattern-matching in loaders and backends is + * exhaustive and compiler-checked. + */ +public sealed interface TensorEncoding { + + /** Human-readable name for diagnostics and memory reports. */ + public val name: String + + /** + * Physical bytes required to store [elementCount] logical elements + * in this encoding, or `null` if the encoding is opaque/variable. + */ + public fun physicalBytes(elementCount: Long): Long? + + /** Dense element-per-slot layout. One element occupies a fixed number of bytes. */ + public data class Dense(val bytesPerElement: Int) : TensorEncoding { + override val name: String get() = "Dense(${bytesPerElement}B)" + override fun physicalBytes(elementCount: Long): Long = elementCount * bytesPerElement + } + + /** GGML Q4_K block quantization: 256 elements per 144-byte block. */ + public data object Q4_K : TensorEncoding { + public const val BLOCK_SIZE: Int = 256 + public const val BYTES_PER_BLOCK: Int = 144 + + override val name: String get() = "Q4_K" + override fun physicalBytes(elementCount: Long): Long { + val blocks = (elementCount + BLOCK_SIZE - 1) / BLOCK_SIZE + return blocks * BYTES_PER_BLOCK + } + } + + /** GGML Q8_0 block quantization: 32 elements per 34-byte block. */ + public data object Q8_0 : TensorEncoding { + public const val BLOCK_SIZE: Int = 32 + public const val BYTES_PER_BLOCK: Int = 34 + + override val name: String get() = "Q8_0" + override fun physicalBytes(elementCount: Long): Long { + val blocks = (elementCount + BLOCK_SIZE - 1) / BLOCK_SIZE + return blocks * BYTES_PER_BLOCK + } + } + + /** Ternary encoding: 2 bits per element, packed 4 elements per byte. 
*/ + public data object TernaryPacked : TensorEncoding { + override val name: String get() = "Ternary" + override fun physicalBytes(elementCount: Long): Long = + (elementCount + 3) / 4 + } + + /** + * TurboQuant PolarOnly encoding: rotation + scalar quantization + bit-packing. + * + * Backend-friendly variant that omits the QJL residual stage. + * Configurable bits per element (2, 3, 4, or 8). + * + * Block layout: [rotationSeed(4B)] [scales(numGroups * 2B)] [codes(packed bits)] + * + * @param bitsPerElement Number of bits per quantized code (2, 3, 4, or 8) + * @param blockSize Number of elements per block (must be power of 2, typically 64 or 128) + */ + public data class TurboQuantPolar( + val bitsPerElement: Int = 4, + val blockSize: Int = 128 + ) : TensorEncoding { + init { + require(bitsPerElement in setOf(2, 3, 4, 8)) { + "bitsPerElement must be 2, 3, 4, or 8, got $bitsPerElement" + } + require(blockSize > 0 && (blockSize and (blockSize - 1)) == 0) { + "blockSize must be a positive power of 2, got $blockSize" + } + } + + /** Number of quantization groups per block (each group has its own scale). */ + val numGroups: Int get() = blockSize / 32 + + override val name: String get() = "TurboQuant-Polar-${bitsPerElement}b" + + override fun physicalBytes(elementCount: Long): Long { + val blocks = (elementCount + blockSize - 1) / blockSize + val seedBytes = 4L // rotation seed per block + val scaleBytes = numGroups * 2L // FP16 scale per group + val codeBytes = (blockSize.toLong() * bitsPerElement + 7) / 8 // packed codes + return blocks * (seedBytes + scaleBytes + codeBytes) + } + } + + /** + * TurboQuant PolarPlusQjl encoding: rotation + scalar quantization + + * QJL residual + bit-packing. + * + * Closest to the official TurboQuant paper. The QJL residual stage + * preserves inner-product accuracy at the cost of additional storage. 
+ * + * @param bitsPerElement Bits for the primary quantization (2, 3, 4, or 8) + * @param residualBits Bits for the QJL residual (typically 1 or 2) + * @param blockSize Elements per block + */ + public data class TurboQuantPolarQjl( + val bitsPerElement: Int = 4, + val residualBits: Int = 1, + val blockSize: Int = 128 + ) : TensorEncoding { + init { + require(bitsPerElement in setOf(2, 3, 4, 8)) { + "bitsPerElement must be 2, 3, 4, or 8, got $bitsPerElement" + } + require(residualBits in 1..4) { + "residualBits must be 1-4, got $residualBits" + } + require(blockSize > 0 && (blockSize and (blockSize - 1)) == 0) { + "blockSize must be a positive power of 2, got $blockSize" + } + } + + val numGroups: Int get() = blockSize / 32 + + override val name: String + get() = "TurboQuant-PolarQjl-${bitsPerElement}b+${residualBits}r" + + override fun physicalBytes(elementCount: Long): Long { + val blocks = (elementCount + blockSize - 1) / blockSize + val seedBytes = 4L + val scaleBytes = numGroups * 2L + val codeBytes = (blockSize.toLong() * bitsPerElement + 7) / 8 + val residualBytes = (blockSize.toLong() * residualBits + 7) / 8 + return blocks * (seedBytes + scaleBytes + codeBytes + residualBytes) + } + } + + /** + * Opaque / unknown encoding. Used as a fallback for formats the runtime + * cannot yet interpret but still wants to carry through without error. 
+ */ + public data class Opaque(override val name: String, val rawBytes: Long) : TensorEncoding { + override fun physicalBytes(elementCount: Long): Long = rawBytes + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/TensorStorage.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/TensorStorage.kt new file mode 100644 index 00000000..b107b4c2 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/TensorStorage.kt @@ -0,0 +1,155 @@ +package sk.ainet.lang.tensor.storage + +import sk.ainet.lang.tensor.Shape + +/** + * Runtime descriptor for a tensor's backing memory. + * + * [TensorStorage] is the main architectural type that replaces ad-hoc + * array passing between loaders, planners, and backends. It carries enough + * information to handle a tensor without inspecting its bytes: + * + * - **What** the values mean: [logicalType] + * - **How** they are stored: [encoding] + * - **Where** the bytes live: [buffer] + [placement] + * - **Layout**: [shape], [byteOffset], [strides], [isContiguous] + * - **Ownership**: via [buffer]'s [BufferHandle] subtype + * + * Existing [sk.ainet.lang.tensor.data.TensorData] remains as a + * compatibility façade. New loaders, planners, and backends should target + * [TensorStorage] directly. + */ +public data class TensorStorage( + val shape: Shape, + val logicalType: LogicalDType, + val encoding: TensorEncoding, + val buffer: BufferHandle, + val placement: Placement = Placement.CPU_HEAP, + val byteOffset: Long = 0, + val strides: LongArray? = null, + val isContiguous: Boolean = true +) { + /** Number of logical elements in this tensor. */ + val elementCount: Long get() = shape.volume.toLong() + + /** Logical size: number of elements x logical element size. */ + val logicalBytes: Long get() = elementCount * logicalType.sizeInBytes + + /** Physical size: actual bytes consumed in the buffer for this tensor. 
*/ + val physicalBytes: Long get() = encoding.physicalBytes(elementCount) ?: buffer.sizeInBytes + + /** Whether this storage is backed by a memory-mapped file. */ + val isFileBacked: Boolean get() = buffer is BufferHandle.FileBacked + + /** Whether this storage is an alias (view) into another buffer. */ + val isAlias: Boolean get() = buffer is BufferHandle.Aliased + + /** Whether this storage is mutable. */ + val isMutable: Boolean get() = buffer.isMutable + + /** Ownership mode of the backing buffer. */ + val ownership: Ownership get() = buffer.ownership + + /** + * Memory report for this single tensor, useful for diagnostics + * and regression testing. + */ + public fun memoryReport(): StorageMemoryReport = StorageMemoryReport( + shape = shape, + logicalType = logicalType, + encoding = encoding, + ownership = ownership, + placement = placement, + logicalBytes = logicalBytes, + physicalBytes = physicalBytes, + isFileBacked = isFileBacked, + isAlias = isAlias, + isMutable = isMutable + ) + + // --- Explicit transfer operations --- + + /** + * Create a new [TensorStorage] with an owned copy of this storage's data. + * The returned storage is independent of the original buffer. + */ + public fun copyMaterialize(): TensorStorage { + val srcBytes = when (val b = buffer) { + is BufferHandle.Owned -> b.data.copyOfRange(b.offset, b.offset + sizeBytes()) + is BufferHandle.Borrowed -> b.data.copyOfRange(b.offset, b.offset + sizeBytes()) + else -> throw UnsupportedOperationException( + "copyMaterialize not yet supported for ${buffer.ownership} buffers" + ) + } + return copy( + buffer = BufferHandle.Owned(srcBytes), + placement = placement.copy(domain = MemoryDomain.HOST_HEAP) + ) + } + + /** + * Ensure this storage resides on the host (CPU heap). + * If already on host, returns `this`. Otherwise copies to host. 
+ */ + public fun copyToHost(): TensorStorage { + if (placement.device == DeviceKind.CPU && placement.domain == MemoryDomain.HOST_HEAP) return this + return copyMaterialize() + } + + /** + * Request a copy of this storage on the specified device. + * Currently only CPU is supported — GPU/NPU backends will override. + * + * @throws UnsupportedOperationException if the target device has no backend + */ + public fun copyToDevice(device: DeviceKind): TensorStorage { + if (device == DeviceKind.CPU) return copyToHost() + throw UnsupportedOperationException("No backend available for device: $device") + } + + /** + * Re-encode this storage into a different physical encoding. + * Currently a stub — actual transcoding requires backend kernels. + * + * @throws UnsupportedOperationException always (until backends implement this) + */ + public fun repackTo(targetEncoding: TensorEncoding): TensorStorage { + if (encoding == targetEncoding) return this + throw UnsupportedOperationException( + "Repacking from ${encoding.name} to ${targetEncoding.name} is not yet implemented" + ) + } + + private fun sizeBytes(): Int = physicalBytes.toInt() + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is TensorStorage) return false + return shape == other.shape && + logicalType == other.logicalType && + encoding == other.encoding && + buffer == other.buffer && + placement == other.placement && + byteOffset == other.byteOffset && + isContiguous == other.isContiguous && + stridesEqual(strides, other.strides) + } + + override fun hashCode(): Int { + var result = shape.hashCode() + result = 31 * result + logicalType.hashCode() + result = 31 * result + encoding.hashCode() + result = 31 * result + buffer.hashCode() + result = 31 * result + placement.hashCode() + result = 31 * result + byteOffset.hashCode() + result = 31 * result + isContiguous.hashCode() + result = 31 * result + (strides?.contentHashCode() ?: 0) + return result + } + + private fun 
stridesEqual(a: LongArray?, b: LongArray?): Boolean = when {
+ a == null && b == null -> true
+ a != null && b != null -> a.contentEquals(b)
+ else -> false
+ }
+}
diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/TensorStorageFactory.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/TensorStorageFactory.kt
new file mode 100644
index 00000000..a93b463e
--- /dev/null
+++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/TensorStorageFactory.kt
@@ -0,0 +1,240 @@
+package sk.ainet.lang.tensor.storage
+
+import sk.ainet.lang.tensor.Shape
+import sk.ainet.lang.tensor.data.DenseFloatArrayTensorData
+import sk.ainet.lang.tensor.data.DenseIntArrayTensorData
+import sk.ainet.lang.tensor.data.FloatArrayTensorData
+import sk.ainet.lang.tensor.data.IntArrayTensorData
+import sk.ainet.lang.tensor.data.Q4_KBlockTensorData
+import sk.ainet.lang.tensor.data.Q4_KTensorData
+import sk.ainet.lang.tensor.data.Q8_0BlockTensorData
+import sk.ainet.lang.tensor.data.Q8_0TensorData
+import sk.ainet.lang.tensor.data.TensorData
+import sk.ainet.lang.types.DType
+import sk.ainet.lang.types.FP32
+import sk.ainet.lang.types.Int32
+
+/**
+ * Factory methods for constructing [TensorStorage] from existing SKaiNET types
+ * and from raw data. These bridge the old TensorData world to the new storage model.
+ */
+public object TensorStorageFactory {
+
+    /**
+     * Wrap a FloatArray as owned dense FLOAT32 storage (copies the array).
+     */
+    public fun fromFloatArray(shape: Shape, data: FloatArray): TensorStorage =
+        TensorStorage(
+            shape = shape,
+            logicalType = LogicalDType.FLOAT32,
+            encoding = TensorEncoding.Dense(bytesPerElement = 4),
+            buffer = BufferHandleFactory.owned(data)
+        )
+
+    /**
+     * Wrap a FloatArray as dense FLOAT32 storage. NOTE: this serializes the floats into a new ByteArray (one copy) and borrows that byte buffer — it is not zero-copy with respect to the input array. 
+ */ + public fun borrowFloatArray(shape: Shape, data: FloatArray): TensorStorage { + val bytes = ByteArray(data.size * 4) + for (i in data.indices) { + val bits = data[i].toRawBits() + val off = i * 4 + bytes[off] = (bits and 0xFF).toByte() + bytes[off + 1] = ((bits shr 8) and 0xFF).toByte() + bytes[off + 2] = ((bits shr 16) and 0xFF).toByte() + bytes[off + 3] = ((bits shr 24) and 0xFF).toByte() + } + return TensorStorage( + shape = shape, + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Dense(bytesPerElement = 4), + buffer = BufferHandleFactory.borrow(bytes) + ) + } + + /** + * Wrap an IntArray as owned dense INT32 storage (copies the array). + */ + public fun fromIntArray(shape: Shape, data: IntArray): TensorStorage = + TensorStorage( + shape = shape, + logicalType = LogicalDType.INT32, + encoding = TensorEncoding.Dense(bytesPerElement = 4), + buffer = BufferHandleFactory.owned(data) + ) + + /** + * Create storage from raw bytes with explicit encoding. + * The byte array is borrowed (not copied). + */ + public fun fromRawBytes( + shape: Shape, + logicalType: LogicalDType, + encoding: TensorEncoding, + data: ByteArray, + placement: Placement = Placement.CPU_HEAP + ): TensorStorage = TensorStorage( + shape = shape, + logicalType = logicalType, + encoding = encoding, + buffer = BufferHandleFactory.borrow(data), + placement = placement + ) + + /** + * Create storage from raw bytes with explicit encoding (owned copy). + */ + public fun fromRawBytesOwned( + shape: Shape, + logicalType: LogicalDType, + encoding: TensorEncoding, + data: ByteArray, + placement: Placement = Placement.CPU_HEAP + ): TensorStorage = TensorStorage( + shape = shape, + logicalType = logicalType, + encoding = encoding, + buffer = BufferHandleFactory.owned(data), + placement = placement + ) + + /** + * Create file-backed storage (for memory-mapped model weights). 
+ */ + public fun fileBacked( + shape: Shape, + logicalType: LogicalDType, + encoding: TensorEncoding, + path: String, + fileOffset: Long, + sizeInBytes: Long + ): TensorStorage = TensorStorage( + shape = shape, + logicalType = logicalType, + encoding = encoding, + buffer = BufferHandleFactory.fileBacked(path, fileOffset, sizeInBytes), + placement = Placement.MMAP_WEIGHTS + ) + + /** + * Bridge: create a [TensorStorage] descriptor from an existing [TensorData]. + * + * This inspects the concrete TensorData type and builds the appropriate + * storage descriptor. The underlying data is borrowed (not copied). + */ + public fun fromTensorData(data: TensorData): TensorStorage { + return when (data) { + is FloatArrayTensorData<*> -> TensorStorage( + shape = data.shape, + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Dense(bytesPerElement = 4), + buffer = BufferHandleFactory.owned(data.buffer) + ) + is IntArrayTensorData<*> -> TensorStorage( + shape = data.shape, + logicalType = LogicalDType.INT32, + encoding = TensorEncoding.Dense(bytesPerElement = 4), + buffer = BufferHandleFactory.owned(data.buffer) + ) + is Q4_KTensorData -> TensorStorage( + shape = data.shape, + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Q4_K, + buffer = BufferHandleFactory.borrow(data.packedData) + ) + is Q8_0TensorData -> TensorStorage( + shape = data.shape, + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Q8_0, + buffer = BufferHandleFactory.borrow(data.packedData) + ) + else -> { + // Fallback: copy to float array and create dense storage + val floats = data.copyToFloatArray() + fromFloatArray(data.shape, floats) + } + } + } + + /** + * Bridge: create a [TensorData] from a [TensorStorage]. + * + * For dense encodings, this interprets the buffer bytes as float/int arrays. + * For packed encodings (Q4_K, Q8_0), this creates the corresponding packed + * TensorData directly. 
The underlying bytes are borrowed (not copied) only when
+ * the buffer is Owned or Borrowed and spans its whole backing array; otherwise a sub-range copy is made.
+ *
+ * For [BufferHandle.FileBacked] or [BufferHandle.DeviceResident], a
+ * [BufferAccessor] must be provided to read the bytes.
+ *
+ * @throws UnsupportedOperationException for FileBacked/DeviceResident without accessor
+ */
+ @Suppress("UNCHECKED_CAST")
+ public fun toTensorData(storage: TensorStorage): TensorData {
+ val bytes = extractBytes(storage)
+
+ return when (storage.encoding) {
+ is TensorEncoding.Dense -> when (storage.logicalType) {
+ LogicalDType.FLOAT32, LogicalDType.FLOAT16, LogicalDType.BFLOAT16 -> {
+ val floats = bytesToFloatArray(bytes)
+ DenseFloatArrayTensorData(storage.shape, floats) as TensorData
+ }
+ LogicalDType.INT32 -> {
+ val ints = bytesToIntArray(bytes)
+ DenseIntArrayTensorData(storage.shape, ints) as TensorData
+ }
+ else -> throw UnsupportedOperationException(
+ "toTensorData not supported for dense ${storage.logicalType}"
+ )
+ }
+ is TensorEncoding.Q4_K -> {
+ Q4_KBlockTensorData.fromRawBytes(storage.shape, bytes) as TensorData
+ }
+ is TensorEncoding.Q8_0 -> {
+ Q8_0BlockTensorData.fromRawBytes(storage.shape, bytes) as TensorData
+ }
+ else -> throw UnsupportedOperationException(
+ "toTensorData not supported for encoding ${storage.encoding.name}"
+ )
+ }
+ }
+
+ private fun extractBytes(storage: TensorStorage): ByteArray = when (val b = storage.buffer) {
+ is BufferHandle.Owned -> {
+ if (b.offset == 0 && b.sizeInBytes.toInt() == b.data.size) b.data
+ else b.data.copyOfRange(b.offset, b.offset + b.sizeInBytes.toInt())
+ }
+ is BufferHandle.Borrowed -> {
+ if (b.offset == 0 && b.sizeInBytes.toInt() == b.data.size) b.data
+ else b.data.copyOfRange(b.offset, b.offset + b.sizeInBytes.toInt())
+ }
+ else -> throw UnsupportedOperationException(
+ "Cannot extract bytes from ${b.ownership} buffer. " +
+ "Use a BufferResolver to read FileBacked/DeviceResident handles first." 
+ ) + } + + private fun bytesToFloatArray(bytes: ByteArray): FloatArray { + val count = bytes.size / 4 + return FloatArray(count) { i -> + val off = i * 4 + Float.fromBits( + (bytes[off].toInt() and 0xFF) or + ((bytes[off + 1].toInt() and 0xFF) shl 8) or + ((bytes[off + 2].toInt() and 0xFF) shl 16) or + ((bytes[off + 3].toInt() and 0xFF) shl 24) + ) + } + } + + private fun bytesToIntArray(bytes: ByteArray): IntArray { + val count = bytes.size / 4 + return IntArray(count) { i -> + val off = i * 4 + (bytes[off].toInt() and 0xFF) or + ((bytes[off + 1].toInt() and 0xFF) shl 8) or + ((bytes[off + 2].toInt() and 0xFF) shl 16) or + ((bytes[off + 3].toInt() and 0xFF) shl 24) + } + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/TurboQuantKvCacheStore.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/TurboQuantKvCacheStore.kt new file mode 100644 index 00000000..048658f8 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/TurboQuantKvCacheStore.kt @@ -0,0 +1,200 @@ +package sk.ainet.lang.tensor.storage + +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.ops.turboquant.TurboQuantBlock +import sk.ainet.lang.tensor.ops.turboquant.TurboQuantCodec +import sk.ainet.lang.tensor.ops.turboquant.TurboQuantConfig +import sk.ainet.lang.tensor.ops.turboquant.RandomRotation + +/** + * KV cache store with TurboQuant compression. + * + * Compresses K/V projections on write using TurboQuant and decompresses + * on read. Supports asymmetric K/V policies (different bit budgets and + * variants for keys vs values). + * + * Each token's K/V projection per head is stored as a [TurboQuantBlock]. + * This gives fine-grained control: different layers/heads could + * potentially use different configurations (though this implementation + * uses uniform config). 
+ */ +public class TurboQuantKvCacheStore( + private val config: KvCacheConfig, + private val keyConfig: TurboQuantConfig, + private val valueConfig: TurboQuantConfig +) : KvCacheStore { + + override val numLayers: Int get() = config.numLayers + override val numHeads: Int get() = config.numHeads + override val headDim: Int get() = config.headDim + override val maxSeqLen: Int get() = config.maxSeqLen + override val keyEncoding: TensorEncoding get() = config.keyEncoding + override val valueEncoding: TensorEncoding get() = config.valueEncoding + override val placement: Placement get() = config.placement + + private var _currentSeqLen: Int = 0 + override val currentSeqLen: Int get() = _currentSeqLen + + // Compressed storage: [layer][position] -> array of TurboQuantBlock (one per head) + private val keyBlocks: Array>> = Array(numLayers) { + Array(maxSeqLen) { arrayOfNulls(numHeads) } + } + private val valueBlocks: Array>> = Array(numLayers) { + Array(maxSeqLen) { arrayOfNulls(numHeads) } + } + + override fun appendToken(layer: Int, key: FloatArray, value: FloatArray) { + requireLayerIndex(layer) + check(_currentSeqLen < maxSeqLen) { + "KV cache is full: currentSeqLen=$_currentSeqLen, maxSeqLen=$maxSeqLen" + } + require(key.size == numHeads * headDim) { + "Key size mismatch: expected ${numHeads * headDim}, got ${key.size}" + } + require(value.size == numHeads * headDim) { + "Value size mismatch: expected ${numHeads * headDim}, got ${value.size}" + } + + val pos = _currentSeqLen + + for (h in 0 until numHeads) { + val headKey = key.copyOfRange(h * headDim, (h + 1) * headDim) + val headValue = value.copyOfRange(h * headDim, (h + 1) * headDim) + + val keySeed = RandomRotation.seedFor(layer, h, pos) + val valueSeed = keySeed xor 0x5A5A5A5A.toInt() + + keyBlocks[layer][pos][h] = TurboQuantCodec.encode( + headKey, keyConfig.copy(seed = keySeed) + ) + valueBlocks[layer][pos][h] = TurboQuantCodec.encode( + headValue, valueConfig.copy(seed = valueSeed) + ) + } + + if (layer == 
numLayers - 1) { + _currentSeqLen++ + } + } + + override fun readKeys(layer: Int, startPos: Int, endPos: Int): FloatArray { + return readRange(keyBlocks, layer, startPos, endPos) + } + + override fun readValues(layer: Int, startPos: Int, endPos: Int): FloatArray { + return readRange(valueBlocks, layer, startPos, endPos) + } + + override fun readKeyStorage(layer: Int, startPos: Int, endPos: Int): TensorStorage { + return toTensorStorage(readKeys(layer, startPos, endPos), endPos - startPos, keyEncoding) + } + + override fun readValueStorage(layer: Int, startPos: Int, endPos: Int): TensorStorage { + return toTensorStorage(readValues(layer, startPos, endPos), endPos - startPos, valueEncoding) + } + + override fun evict(fromPos: Int) { + require(fromPos in 0..currentSeqLen) { + "evict fromPos=$fromPos out of range [0, $currentSeqLen]" + } + for (layer in 0 until numLayers) { + for (pos in fromPos until maxSeqLen) { + for (h in 0 until numHeads) { + keyBlocks[layer][pos][h] = null + valueBlocks[layer][pos][h] = null + } + } + } + _currentSeqLen = fromPos + } + + override fun clear() { + _currentSeqLen = 0 + for (layer in 0 until numLayers) { + for (pos in 0 until maxSeqLen) { + for (h in 0 until numHeads) { + keyBlocks[layer][pos][h] = null + valueBlocks[layer][pos][h] = null + } + } + } + } + + override fun memoryReport(): KvCacheMemoryReport { + var keyBytes = 0L + var valueBytes = 0L + for (layer in 0 until numLayers) { + for (pos in 0 until _currentSeqLen) { + for (h in 0 until numHeads) { + keyBytes += keyBlocks[layer][pos][h]?.sizeInBytes ?: 0 + valueBytes += valueBlocks[layer][pos][h]?.sizeInBytes ?: 0 + } + } + } + val logicalPerLayer = numHeads.toLong() * _currentSeqLen * headDim * 4 + return KvCacheMemoryReport( + numLayers = numLayers, + numHeads = numHeads, + headDim = headDim, + maxSeqLen = maxSeqLen, + currentSeqLen = _currentSeqLen, + keyEncoding = keyEncoding, + valueEncoding = valueEncoding, + placement = placement, + keyPhysicalBytes = keyBytes, + 
valuePhysicalBytes = valueBytes, + keyLogicalBytes = numLayers * logicalPerLayer, + valueLogicalBytes = numLayers * logicalPerLayer + ) + } + + // --- Internal --- + + private fun readRange( + blocks: Array>>, + layer: Int, + startPos: Int, + endPos: Int + ): FloatArray { + requireLayerIndex(layer) + require(startPos in 0..endPos) { "Invalid range: startPos=$startPos, endPos=$endPos" } + require(endPos <= _currentSeqLen) { "endPos=$endPos exceeds currentSeqLen=$_currentSeqLen" } + + val seqLen = endPos - startPos + // Output: [numHeads, seqLen, headDim] + val result = FloatArray(numHeads * seqLen * headDim) + + for (h in 0 until numHeads) { + for (p in startPos until endPos) { + val block = blocks[layer][p][h] + ?: error("Missing block at layer=$layer, pos=$p, head=$h") + val decoded = TurboQuantCodec.decode(block) + val dstOffset = h * seqLen * headDim + (p - startPos) * headDim + decoded.copyInto(result, dstOffset) + } + } + return result + } + + private fun toTensorStorage(data: FloatArray, seqLen: Int, encoding: TensorEncoding): TensorStorage { + val bytes = ByteArray(data.size * 4) + for (i in data.indices) { + val bits = data[i].toRawBits() + bytes[i * 4] = (bits and 0xFF).toByte() + bytes[i * 4 + 1] = ((bits shr 8) and 0xFF).toByte() + bytes[i * 4 + 2] = ((bits shr 16) and 0xFF).toByte() + bytes[i * 4 + 3] = ((bits shr 24) and 0xFF).toByte() + } + return TensorStorage( + shape = Shape(numHeads, seqLen, headDim), + logicalType = LogicalDType.FLOAT32, + encoding = encoding, + buffer = BufferHandle.Owned(bytes), + placement = placement + ) + } + + private fun requireLayerIndex(layer: Int) { + require(layer in 0 until numLayers) { "Layer $layer out of range [0, $numLayers)" } + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/turboquant/BitPackerTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/turboquant/BitPackerTest.kt new file mode 100644 index 00000000..35fc1729 --- 
/dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/turboquant/BitPackerTest.kt @@ -0,0 +1,93 @@ +package sk.ainet.lang.tensor.ops.turboquant + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class BitPackerTest { + + @Test + fun pack4BitRoundTrip() { + val codes = byteArrayOf(0, 1, -1, 7, -7, 3, -3, 0) + val packed = BitPacker.pack(codes, 4) + val unpacked = BitPacker.unpack(packed, codes.size, 4) + + assertTrue(codes.contentEquals(unpacked), + "4-bit round trip failed: ${codes.toList()} -> ${unpacked.toList()}") + } + + @Test + fun pack2BitRoundTrip() { + val codes = byteArrayOf(0, 1, -1, 0, 1, -1, 0, 1) + val packed = BitPacker.pack(codes, 2) + val unpacked = BitPacker.unpack(packed, codes.size, 2) + + assertTrue(codes.contentEquals(unpacked), + "2-bit round trip failed: ${codes.toList()} -> ${unpacked.toList()}") + } + + @Test + fun pack3BitRoundTrip() { + val codes = byteArrayOf(0, 1, -1, 3, -3, 2, -2, 0) + val packed = BitPacker.pack(codes, 3) + val unpacked = BitPacker.unpack(packed, codes.size, 3) + + assertTrue(codes.contentEquals(unpacked), + "3-bit round trip failed: ${codes.toList()} -> ${unpacked.toList()}") + } + + @Test + fun pack8BitRoundTrip() { + val codes = byteArrayOf(0, 127, -128, 1, -1, 64, -64, 100) + val packed = BitPacker.pack(codes, 8) + val unpacked = BitPacker.unpack(packed, codes.size, 8) + + assertTrue(codes.contentEquals(unpacked)) + } + + @Test + fun pack4BitCompression() { + val codes = ByteArray(100) + val packed = BitPacker.pack(codes, 4) + assertEquals(50, packed.size, "4-bit should be 50% size") + } + + @Test + fun pack2BitCompression() { + val codes = ByteArray(100) + val packed = BitPacker.pack(codes, 2) + assertEquals(25, packed.size, "2-bit should be 25% size") + } + + @Test + fun packedSize() { + assertEquals(50, BitPacker.packedSize(100, 4)) + assertEquals(25, BitPacker.packedSize(100, 2)) + assertEquals(100, BitPacker.packedSize(100, 8)) 
+ assertEquals(38, BitPacker.packedSize(100, 3)) // (100*3+7)/8 + } + + @Test + fun oddCountRoundTrip() { + // Non-aligned count + val codes = byteArrayOf(1, -1, 0) + val packed4 = BitPacker.pack(codes, 4) + val unpacked4 = BitPacker.unpack(packed4, 3, 4) + assertTrue(codes.contentEquals(unpacked4)) + + val packed2 = BitPacker.pack(codes, 2) + val unpacked2 = BitPacker.unpack(packed2, 3, 2) + // 2-bit can only represent -1, 0, 1 — codes[0]=1, codes[1]=-1, codes[2]=0 all valid + assertTrue(codes.contentEquals(unpacked2)) + } + + @Test + fun pack4BitAllValues() { + // Test all valid 4-bit values: -7 to 7 + val codes = ByteArray(15) { (it - 7).toByte() } + val packed = BitPacker.pack(codes, 4) + val unpacked = BitPacker.unpack(packed, 15, 4) + assertTrue(codes.contentEquals(unpacked), + "All 4-bit values: ${codes.toList()} -> ${unpacked.toList()}") + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/turboquant/RandomRotationTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/turboquant/RandomRotationTest.kt new file mode 100644 index 00000000..3de2f1be --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/turboquant/RandomRotationTest.kt @@ -0,0 +1,96 @@ +package sk.ainet.lang.tensor.ops.turboquant + +import kotlin.math.abs +import kotlin.math.sqrt +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class RandomRotationTest { + + @Test + fun rotateInverseRoundTrip() { + val input = floatArrayOf(1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f) + val original = input.copyOf() + val seed = 42 + + RandomRotation.rotate(input, seed) + RandomRotation.inverseRotate(input, seed) + + for (i in original.indices) { + assertTrue(abs(original[i] - input[i]) < 1e-4f, + "Element $i: expected ${original[i]}, got ${input[i]}") + } + } + + @Test + fun rotateChangesValues() { + val input = floatArrayOf(1f, 0f, 0f, 0f) + val original = 
input.copyOf()
+
+        RandomRotation.rotate(input, 42)
+
+        // At least some values should change
+        var changed = false
+        for (i in input.indices) {
+            if (abs(input[i] - original[i]) > 1e-6f) changed = true
+        }
+        assertTrue(changed, "Rotation should modify the vector")
+    }
+
+    @Test
+    fun rotateDeterministic() {
+        val a = floatArrayOf(1f, 2f, 3f, 4f)
+        val b = floatArrayOf(1f, 2f, 3f, 4f)
+
+        RandomRotation.rotate(a, 123)
+        RandomRotation.rotate(b, 123)
+
+        assertTrue(a.contentEquals(b), "Same seed should produce same rotation")
+    }
+
+    @Test
+    fun rotatePreservesNorm() {
+        val input = floatArrayOf(1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f)
+        val normBefore = sqrt(input.sumOf { (it * it).toDouble() }).toFloat()
+
+        RandomRotation.rotate(input, 42)
+
+        val normAfter = sqrt(input.sumOf { (it * it).toDouble() }).toFloat()
+        // WHT preserves norm (orthogonal transform)
+        assertTrue(abs(normBefore - normAfter) < 0.1f * normBefore,
+            "Norm should be approximately preserved: before=$normBefore, after=$normAfter")
+    }
+
+    @Test
+    fun seedForIsDeterministic() {
+        val s1 = RandomRotation.seedFor(0, 1, 2)
+        val s2 = RandomRotation.seedFor(0, 1, 2)
+        assertEquals(s1, s2)
+    }
+
+    @Test
+    fun seedForDistribution() {
+        // Different inputs should produce different seeds
+        val seeds = mutableSetOf<Int>()
+        for (l in 0..3) {
+            for (h in 0..3) {
+                for (p in 0..3) {
+                    seeds.add(RandomRotation.seedFor(l, h, p))
+                }
+            }
+        }
+        // 64 inputs should produce more than 50 distinct seeds (well-distributed)
+        assertTrue(seeds.size > 50, "Seeds should be well-distributed, got ${seeds.size} unique out of 64")
+    }
+
+    @Test
+    fun walshHadamardSmall() {
+        // WHT of [1, 1, 1, 1] should give [2, 0, 0, 0] (before normalization: [4, 0, 0, 0])
+        // After normalization by 1/sqrt(4) = 0.5: [2, 0, 0, 0]
+        val input = floatArrayOf(1f, 1f, 1f, 1f)
+        RandomRotation.walshHadamard(input)
+        assertTrue(abs(input[0] - 2f) < 1e-5f, "WHT[0] should be 2, got ${input[0]}")
+        assertTrue(abs(input[1]) < 1e-5f, "WHT[1] should be 0, got 
${input[1]}") + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/turboquant/ScalarQuantizerTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/turboquant/ScalarQuantizerTest.kt new file mode 100644 index 00000000..d7c17a23 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/turboquant/ScalarQuantizerTest.kt @@ -0,0 +1,86 @@ +package sk.ainet.lang.tensor.ops.turboquant + +import kotlin.math.abs +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class ScalarQuantizerTest { + + @Test + fun quantize4BitRoundTrip() { + val input = floatArrayOf(0.5f, -0.3f, 1.0f, -1.0f, 0.0f, 0.7f, -0.8f, 0.2f) + val quantized = ScalarQuantizer.quantize(input, 4) + val output = ScalarQuantizer.dequantize(quantized) + + assertEquals(input.size, output.size) + // 4-bit: 15 levels, so max error ≈ scale/2 ≈ absMax/14 + for (i in input.indices) { + assertTrue(abs(input[i] - output[i]) < 0.2f, + "Element $i: input=${input[i]}, output=${output[i]}") + } + } + + @Test + fun quantize8BitHighAccuracy() { + val input = FloatArray(64) { (it - 32).toFloat() / 32f } + val quantized = ScalarQuantizer.quantize(input, 8) + val output = ScalarQuantizer.dequantize(quantized) + + for (i in input.indices) { + assertTrue(abs(input[i] - output[i]) < 0.02f, + "8-bit should be very accurate: input=${input[i]}, output=${output[i]}") + } + } + + @Test + fun quantize2BitCoarse() { + val input = floatArrayOf(1f, -1f, 0.5f, -0.5f) + val quantized = ScalarQuantizer.quantize(input, 2) + assertEquals(2, quantized.bits) + // 2-bit: only 3 levels (-1, 0, 1) * scale + val output = ScalarQuantizer.dequantize(quantized) + assertEquals(input.size, output.size) + } + + @Test + fun quantizeAllZeros() { + val input = FloatArray(32) + val quantized = ScalarQuantizer.quantize(input, 4) + val output = ScalarQuantizer.dequantize(quantized) + + for (v in output) 
assertEquals(0f, v) + } + + @Test + fun quantizeMultipleGroups() { + // 64 elements = 2 groups of 32 + val input = FloatArray(64) { if (it < 32) 1f else -1f } + val quantized = ScalarQuantizer.quantize(input, 4) + assertEquals(2, quantized.numGroups) + assertEquals(64, quantized.elementCount) + } + + @Test + fun quantizeNonMultipleOfGroupSize() { + // 10 elements, not a multiple of 32 + val input = FloatArray(10) { it.toFloat() / 10f } + val quantized = ScalarQuantizer.quantize(input, 4) + val output = ScalarQuantizer.dequantize(quantized) + assertEquals(10, output.size) + } + + @Test + fun dequantizeIntoWorks() { + val input = floatArrayOf(1f, -1f, 0.5f, -0.5f) + val quantized = ScalarQuantizer.quantize(input, 4) + val output = FloatArray(10) + ScalarQuantizer.dequantizeInto(quantized.codes, quantized.scales, output, offset = 3) + + // First 3 should be 0 + assertEquals(0f, output[0]) + assertEquals(0f, output[2]) + // Elements at offset should have values + assertTrue(abs(output[3]) > 0f || abs(output[4]) > 0f) + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/turboquant/TurboQuantCodecTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/turboquant/TurboQuantCodecTest.kt new file mode 100644 index 00000000..e7075ab6 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/turboquant/TurboQuantCodecTest.kt @@ -0,0 +1,184 @@ +package sk.ainet.lang.tensor.ops.turboquant + +import kotlin.math.abs +import kotlin.math.sqrt +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue + +class TurboQuantCodecTest { + + private fun meanSquaredError(a: FloatArray, b: FloatArray): Float { + require(a.size == b.size) + var sum = 0.0 + for (i in a.indices) { + val diff = a[i] - b[i] + sum += diff * diff + } + return (sum / 
a.size).toFloat() + } + + private fun relativeError(original: FloatArray, reconstructed: FloatArray): Float { + val norm = sqrt(original.sumOf { (it * it).toDouble() }).toFloat() + if (norm == 0f) return 0f + val mse = meanSquaredError(original, reconstructed) + return sqrt(mse.toDouble()).toFloat() / norm + } + + // --- PolarOnly --- + + @Test + fun polarOnly4BitRoundTrip() { + val input = FloatArray(128) { (it - 64).toFloat() / 64f } + val config = TurboQuantConfig.polarOnly(bits = 4, seed = 42) + + val block = TurboQuantCodec.encode(input, config) + assertTrue(block.isPolarOnly) + assertNull(block.residual) + assertEquals(128, block.elementCount) + assertEquals(4, block.bits) + + val output = TurboQuantCodec.decode(block) + assertEquals(input.size, output.size) + + val re = relativeError(input, output) + assertTrue(re < 0.3f, "4-bit PolarOnly relative error should be < 30%, got ${re * 100}%") + } + + @Test + fun polarOnly8BitHighAccuracy() { + val input = FloatArray(128) { (it - 64).toFloat() / 64f } + val config = TurboQuantConfig.polarOnly(bits = 8, seed = 42) + + val block = TurboQuantCodec.encode(input, config) + val output = TurboQuantCodec.decode(block) + + val re = relativeError(input, output) + assertTrue(re < 0.05f, "8-bit should have < 5% relative error, got ${re * 100}%") + } + + @Test + fun polarOnly2BitCoarse() { + val input = FloatArray(64) { (it - 32).toFloat() / 32f } + val config = TurboQuantConfig.polarOnly(bits = 2, seed = 42) + + val block = TurboQuantCodec.encode(input, config) + val output = TurboQuantCodec.decode(block) + + assertEquals(input.size, output.size) + // 2-bit is very coarse, just verify it runs and output is finite + for (v in output) { + assertFalse(v.isNaN(), "Output should not contain NaN") + assertFalse(v.isInfinite(), "Output should not contain Infinity") + } + } + + @Test + fun polarOnly3Bit() { + val input = FloatArray(128) { (it - 64).toFloat() / 64f } + val config = TurboQuantConfig.polarOnly(bits = 3, seed = 42) + + 
val block = TurboQuantCodec.encode(input, config) + val output = TurboQuantCodec.decode(block) + + val re = relativeError(input, output) + assertTrue(re < 0.5f, "3-bit relative error should be < 50%, got ${re * 100}%") + } + + // --- PolarPlusQjl --- + + @Test + fun polarPlusQjl4BitRoundTrip() { + val input = FloatArray(128) { (it - 64).toFloat() / 64f } + val config = TurboQuantConfig.polarPlusQjl(bits = 4, residualBits = 1, seed = 42) + + val block = TurboQuantCodec.encode(input, config) + assertFalse(block.isPolarOnly) + assertNotNull(block.residual) + + val output = TurboQuantCodec.decode(block) + assertEquals(input.size, output.size) + + // With QJL, error should not be worse than without + val re = relativeError(input, output) + assertTrue(re < 0.4f, "4-bit+QJL relative error should be reasonable, got ${re * 100}%") + } + + @Test + fun polarPlusQjl2BitResidual() { + val input = FloatArray(64) { (it - 32).toFloat() / 32f } + val config = TurboQuantConfig.polarPlusQjl(bits = 4, residualBits = 2, seed = 42) + + val block = TurboQuantCodec.encode(input, config) + assertNotNull(block.residual) + assertEquals(2, block.residual!!.residualBits) + + val output = TurboQuantCodec.decode(block) + assertEquals(input.size, output.size) + } + + // --- Compression --- + + @Test + fun encodedSizeSmaller() { + val input = FloatArray(128) + val config = TurboQuantConfig.polarOnly(bits = 4, seed = 0) + val block = TurboQuantCodec.encode(input, config) + + val originalSize = 128 * 4 // 512 bytes as FP32 + assertTrue(block.sizeInBytes < originalSize, + "Encoded size (${block.sizeInBytes}) should be < original ($originalSize)") + } + + // --- Determinism --- + + @Test + fun encodingIsDeterministic() { + val input = FloatArray(64) { it.toFloat() } + val config = TurboQuantConfig.polarOnly(bits = 4, seed = 42) + + val block1 = TurboQuantCodec.encode(input, config) + val block2 = TurboQuantCodec.encode(input, config) + + assertEquals(block1, block2, "Same input + config should produce 
identical blocks") + } + + // --- Zero input --- + + @Test + fun zeroInputRoundTrip() { + val input = FloatArray(64) + val config = TurboQuantConfig.polarOnly(bits = 4, seed = 42) + + val block = TurboQuantCodec.encode(input, config) + val output = TurboQuantCodec.decode(block) + + for (v in output) { + assertTrue(abs(v) < 1e-5f, "Zero input should reconstruct to ~zero, got $v") + } + } + + // --- Config --- + + @Test + fun configValidation() { + // Valid configs + TurboQuantConfig.polarOnly(bits = 2) + TurboQuantConfig.polarOnly(bits = 3) + TurboQuantConfig.polarOnly(bits = 4) + TurboQuantConfig.polarOnly(bits = 8) + TurboQuantConfig.polarPlusQjl(bits = 4, residualBits = 1) + TurboQuantConfig.polarPlusQjl(bits = 4, residualBits = 4) + } + + @Test + fun encodedSizeComputation() { + val config = TurboQuantConfig.polarOnly(bits = 4) + val size = TurboQuantCodec.encodedSize(128, config) + assertTrue(size > 0) + assertTrue(size < 128 * 4) // Less than FP32 + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/AcceptanceCriteriaTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/AcceptanceCriteriaTest.kt new file mode 100644 index 00000000..2d712a98 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/AcceptanceCriteriaTest.kt @@ -0,0 +1,249 @@ +package sk.ainet.lang.tensor.storage + +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.data.DenseFloatArrayTensorData +import sk.ainet.lang.tensor.data.Q4_KBlockTensorData +import sk.ainet.lang.tensor.data.Q8_0BlockTensorData +import sk.ainet.lang.types.FP32 +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertNotEquals +import kotlin.test.assertTrue + +/** + * End-to-end acceptance criteria tests for the Memory Architecture PRD. 
+ * + * AC1: Large GGUF can be parsed without whole-file heap loading + * → Tested via StreamingGGUFReader integration tests (requires file I/O, in gguf module) + * + * AC2: Tensors stay borrowed/mapped/packed after loading + * AC3: Tensor views remain zero-copy, copy operations are explicit + * AC4: Quantized tensors exist as packed layouts end-to-end + * AC5: Every tensor reports encoding, ownership, placement, logical size, physical size + * AC6: Runtime distinguishes immutable weights from mutable runtime buffers + */ +class AcceptanceCriteriaTest { + + // --- AC2: Tensors stay borrowed/mapped/packed after loading --- + + @Test + fun ac2_borrowedStorageSurvivesConversion() { + val rawQ4K = ByteArray(144) + val packed = Q4_KBlockTensorData.fromRawBytes(Shape(256), rawQ4K) + val storage = TensorStorageFactory.fromTensorData(packed) + + assertEquals(Ownership.BORROWED, storage.ownership) + assertEquals(TensorEncoding.Q4_K, storage.encoding) + assertFalse(storage.isMutable) + } + + @Test + fun ac2_fileBackedStoragePreservesPlacement() { + val storage = TensorStorageFactory.fileBacked( + shape = Shape(1024, 768), + logicalType = LogicalDType.FLOAT16, + encoding = TensorEncoding.Dense(2), + path = "/model/weights.bin", + fileOffset = 0, + sizeInBytes = 1024L * 768 * 2 + ) + + assertTrue(storage.isFileBacked) + assertFalse(storage.isMutable) + assertEquals(MemoryDomain.MMAP_FILE, storage.placement.domain) + assertEquals(Residency.PERSISTENT, storage.placement.residency) + } + + // --- AC3: Tensor views zero-copy, copies explicit --- + + @Test + fun ac3_borrowedConstructorDoesNotCopy() { + val original = floatArrayOf(1f, 2f, 3f) + val storage = TensorStorageFactory.borrowFloatArray(Shape(3), original) + assertEquals(Ownership.BORROWED, storage.ownership) + } + + @Test + fun ac3_ownedConstructorCopies() { + val original = floatArrayOf(1f, 2f, 3f) + val storage = TensorStorageFactory.fromFloatArray(Shape(3), original) + assertEquals(Ownership.OWNED, storage.ownership) + 
} + + @Test + fun ac3_aliasedSliceSharesParentBuffer() { + val parent = BufferHandle.Owned(ByteArray(1000)) + val alias = BufferHandle.Aliased(parent, byteOffset = 100, sizeInBytes = 200) + + assertEquals(Ownership.ALIASED, alias.ownership) + assertTrue(alias.isMutable) // inherits from parent + assertEquals(200L, alias.sizeInBytes) + } + + // --- AC4: Quantized tensors as packed layouts end-to-end --- + + @Test + fun ac4_q4kStaysPackedEndToEnd() { + // 1. Create from raw bytes (simulating file load) + val rawBytes = ByteArray(144) // 1 Q4_K block + val packed = Q4_KBlockTensorData.fromRawBytes(Shape(256), rawBytes) + + // 2. Verify it's still packed (not densified) + assertTrue(packed is PackedBlockStorage) + assertEquals(TensorEncoding.Q4_K, (packed as PackedBlockStorage).encoding) + assertEquals(144L, packed.physicalBytes) + + // 3. Convert to TensorStorage descriptor + val storage = packed.toTensorStorage() + assertEquals(TensorEncoding.Q4_K, storage.encoding) + assertEquals(144L, storage.physicalBytes) + assertEquals(1024L, storage.logicalBytes) // logical FP32: 256 * 4 + + // 4. 
Physical bytes << logical bytes (compression working) + assertTrue(storage.physicalBytes < storage.logicalBytes) + } + + @Test + fun ac4_q80StaysPackedEndToEnd() { + val rawBytes = ByteArray(34 * 4) // 4 Q8_0 blocks = 128 elements + val packed = Q8_0BlockTensorData.fromRawBytes(Shape(128), rawBytes) + + assertTrue(packed is PackedBlockStorage) + assertEquals(TensorEncoding.Q8_0, (packed as PackedBlockStorage).encoding) + assertEquals(136L, packed.physicalBytes) // 4 * 34 + } + + // --- AC5: Every tensor reports encoding, ownership, placement, sizes --- + + @Test + fun ac5_denseFloatReportsAllFields() { + val data = floatArrayOf(1f, 2f, 3f, 4f, 5f, 6f) + val td = DenseFloatArrayTensorData(Shape(2, 3), data) + val storage = TensorStorageFactory.fromTensorData(td) + val report = storage.memoryReport() + + assertEquals(LogicalDType.FLOAT32, report.logicalType) + assertEquals("Dense(4B)", report.encoding.name) + assertEquals(Ownership.OWNED, report.ownership) + assertEquals(24L, report.logicalBytes) + assertEquals(24L, report.physicalBytes) + assertFalse(report.isFileBacked) + assertFalse(report.isAlias) + assertTrue(report.isMutable) + } + + @Test + fun ac5_packedQ4KReportsAllFields() { + val storage = TensorStorage( + shape = Shape(512), + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Q4_K, + buffer = BufferHandle.Borrowed(ByteArray(288)), // 2 blocks + placement = Placement.CPU_HEAP + ) + val report = storage.memoryReport() + + assertEquals(LogicalDType.FLOAT32, report.logicalType) + assertEquals("Q4_K", report.encoding.name) + assertEquals(Ownership.BORROWED, report.ownership) + assertEquals(DeviceKind.CPU, report.placement.device) + assertEquals(MemoryDomain.HOST_HEAP, report.placement.domain) + assertEquals(2048L, report.logicalBytes) // 512 * 4 + assertEquals(288L, report.physicalBytes) // 2 Q4_K blocks + assertTrue(report.compressionRatio > 7.0) + } + + @Test + fun ac5_fileBackedReportsAllFields() { + val storage = TensorStorage( + shape = 
Shape(1000), + logicalType = LogicalDType.FLOAT16, + encoding = TensorEncoding.Dense(2), + buffer = BufferHandle.FileBacked("/model.bin", 4096, 2000), + placement = Placement.MMAP_WEIGHTS + ) + val report = storage.memoryReport() + + assertTrue(report.isFileBacked) + assertEquals(Ownership.FILE_BACKED, report.ownership) + assertEquals(MemoryDomain.MMAP_FILE, report.placement.domain) + assertEquals(Residency.PERSISTENT, report.placement.residency) + assertFalse(report.isMutable) + } + + // --- AC6: Distinguish immutable weights from mutable runtime buffers --- + + @Test + fun ac6_weightsAreImmutableAndPersistent() { + val weights = TensorStorage( + shape = Shape(768, 768), + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Dense(4), + buffer = BufferHandle.FileBacked("/model.bin", 0, 768L * 768 * 4), + placement = Placement.MMAP_WEIGHTS + ) + + assertFalse(weights.isMutable) + assertEquals(Residency.PERSISTENT, weights.placement.residency) + assertTrue(weights.isFileBacked) + } + + @Test + fun ac6_activationsAreMutableAndTransient() { + val activations = TensorStorage( + shape = Shape(32, 768), + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Dense(4), + buffer = BufferHandle.Owned(ByteArray(32 * 768 * 4)), + placement = Placement.CPU_HEAP + ) + + assertTrue(activations.isMutable) + assertEquals(Residency.TRANSIENT, activations.placement.residency) + assertFalse(activations.isFileBacked) + } + + @Test + fun ac6_plannerDistinguishesWeightsFromActivations() { + val planner = MemoryPlanner() + + val weightPlacement = planner.suggestWeightPlacement(isFileBacked = true) + assertEquals(MemoryDomain.MMAP_FILE, weightPlacement.domain) + assertEquals(Residency.PERSISTENT, weightPlacement.residency) + + val activationPlacement = planner.suggestActivationPlacement() + assertEquals(MemoryDomain.HOST_HEAP, activationPlacement.domain) + assertEquals(Residency.TRANSIENT, activationPlacement.residency) + + assertNotEquals(weightPlacement, 
activationPlacement) + } + + // --- Aggregate observability --- + + @Test + fun memoryTrackerDetectsUnexpectedCopies() { + val tracker = MemoryTracker() + + // Load two tensors — one borrowed, one owned (copy) + tracker.record("borrowed_weight", TensorStorage( + shape = Shape(100), + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Dense(4), + buffer = BufferHandle.Borrowed(ByteArray(400)) + )) + tracker.record("copied_activation", TensorStorage( + shape = Shape(100), + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Dense(4), + buffer = BufferHandle.Owned(ByteArray(400)) + )) + tracker.recordCopy("copied_activation", 400) + + val report = tracker.report() + assertEquals(1L, report.copyCount) + assertEquals(400L, report.copyBytes) + assertEquals(1, report.borrowedCount) + assertEquals(1, report.ownedCount) + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/ActiveMemoryTrackerTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/ActiveMemoryTrackerTest.kt new file mode 100644 index 00000000..f8da35df --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/ActiveMemoryTrackerTest.kt @@ -0,0 +1,77 @@ +package sk.ainet.lang.tensor.storage + +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.data.DenseTensorDataFactory +import sk.ainet.lang.types.FP32 +import kotlin.test.AfterTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull + +class ActiveMemoryTrackerTest { + + @AfterTest + fun teardown() { + ActiveMemoryTracker.current = null + } + + @Test + fun recordCopy_withActiveTracker_capturesCopy() { + val tracker = MemoryTracker() + ActiveMemoryTracker.current = tracker + + ActiveMemoryTracker.recordCopy("test_source", 100) + + val report = tracker.report() + assertEquals(1L, report.copyCount) + assertEquals(100L, report.copyBytes) + } + + @Test + fun 
recordCopy_withNullTracker_noOp() { + ActiveMemoryTracker.current = null + // Should not crash + ActiveMemoryTracker.recordCopy("test", 50) + } + + @Test + fun trackerCaptures_DenseTensorDataFactory_copy() { + val tracker = MemoryTracker() + ActiveMemoryTracker.current = tracker + + val factory = DenseTensorDataFactory() + factory.fromFloatArray(Shape(10), FP32::class, FloatArray(10)) + + val report = tracker.report() + // fromFloatArray calls createFloatTensorData which records a copy + assertEquals(1L, report.copyCount) + assertEquals(40L, report.copyBytes) // 10 floats * 4 bytes + } + + @Test + fun multipleCopies_accumulate() { + val tracker = MemoryTracker() + ActiveMemoryTracker.current = tracker + + ActiveMemoryTracker.recordCopy("a", 100) + ActiveMemoryTracker.recordCopy("b", 200) + ActiveMemoryTracker.recordCopy("c", 300) + + val report = tracker.report() + assertEquals(3L, report.copyCount) + assertEquals(600L, report.copyBytes) + } + + @Test + fun clearResets_afterTracking() { + val tracker = MemoryTracker() + ActiveMemoryTracker.current = tracker + + ActiveMemoryTracker.recordCopy("x", 50) + tracker.clear() + + val report = tracker.report() + assertEquals(0L, report.copyCount) + assertEquals(0L, report.copyBytes) + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/BufferHandleFactoryTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/BufferHandleFactoryTest.kt new file mode 100644 index 00000000..aabd2d23 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/BufferHandleFactoryTest.kt @@ -0,0 +1,215 @@ +package sk.ainet.lang.tensor.storage + +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.data.DenseFloatArrayTensorData +import sk.ainet.lang.tensor.data.DenseIntArrayTensorData +import sk.ainet.lang.tensor.data.Q4_KBlockTensorData +import sk.ainet.lang.tensor.data.Q8_0BlockTensorData +import 
sk.ainet.lang.types.FP32 +import sk.ainet.lang.types.Int32 +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class BufferHandleFactoryTest { + + @Test + fun ownedFromByteArrayCopiesData() { + val original = byteArrayOf(1, 2, 3, 4) + val handle = BufferHandleFactory.owned(original) + assertEquals(4L, handle.sizeInBytes) + assertTrue(handle.isMutable) + // Modifying original should not affect the handle + original[0] = 99 + assertEquals(1, handle.data[0]) + } + + @Test + fun ownedFromFloatArrayConvertsToBytes() { + val floats = floatArrayOf(1.0f, 2.0f) + val handle = BufferHandleFactory.owned(floats) + assertEquals(8L, handle.sizeInBytes) // 2 floats * 4 bytes + // Verify first float bytes (little-endian IEEE 754 for 1.0f = 0x3F800000) + val bits = (handle.data[3].toInt() and 0xFF shl 24) or + (handle.data[2].toInt() and 0xFF shl 16) or + (handle.data[1].toInt() and 0xFF shl 8) or + (handle.data[0].toInt() and 0xFF) + assertEquals(1.0f, Float.fromBits(bits)) + } + + @Test + fun ownedFromIntArrayConvertsToBytes() { + val ints = intArrayOf(42, 100) + val handle = BufferHandleFactory.owned(ints) + assertEquals(8L, handle.sizeInBytes) + // Verify first int (little-endian: 42 = 0x0000002A) + val v = (handle.data[3].toInt() and 0xFF shl 24) or + (handle.data[2].toInt() and 0xFF shl 16) or + (handle.data[1].toInt() and 0xFF shl 8) or + (handle.data[0].toInt() and 0xFF) + assertEquals(42, v) + } + + @Test + fun borrowSharesArray() { + val data = byteArrayOf(10, 20, 30) + val handle = BufferHandleFactory.borrow(data) + assertEquals(3L, handle.sizeInBytes) + assertFalse(handle.isMutable) + // Same backing array + data[0] = 99 + assertEquals(99, handle.data[0]) + } + + @Test + fun borrowWithOffsetAndLength() { + val data = byteArrayOf(0, 1, 2, 3, 4, 5) + val handle = BufferHandleFactory.borrow(data, offset = 2, length = 3) + assertEquals(3L, handle.sizeInBytes) + assertEquals(2, handle.offset) + } + + 
class TensorStorageFactoryTest {

    @Test
    fun fromFloatArrayCreatesDenseStorage() {
        val shape = Shape(2, 3)
        val values = FloatArray(6) { it.toFloat() }

        val storage = TensorStorageFactory.fromFloatArray(shape, values)

        assertEquals(shape, storage.shape)
        assertEquals(LogicalDType.FLOAT32, storage.logicalType)
        assertEquals(TensorEncoding.Dense(4), storage.encoding)
        assertEquals(Ownership.OWNED, storage.ownership)
        assertEquals(24L, storage.logicalBytes)  // 6 elements * 4 bytes
        assertEquals(24L, storage.physicalBytes) // dense: physical == logical
        assertTrue(storage.isMutable)
    }

    @Test
    fun fromIntArrayCreatesDenseStorage() {
        val storage = TensorStorageFactory.fromIntArray(Shape(4), intArrayOf(1, 2, 3, 4))

        assertEquals(LogicalDType.INT32, storage.logicalType)
        assertEquals(Ownership.OWNED, storage.ownership)
        assertEquals(16L, storage.physicalBytes) // 4 ints * 4 bytes
    }

    @Test
    fun fromRawBytesCreatesBorrowedStorage() {
        val packed = ByteArray(144) // exactly one Q4_K block (covers 256 elements)

        val storage = TensorStorageFactory.fromRawBytes(
            shape = Shape(256),
            logicalType = LogicalDType.FLOAT32,
            encoding = TensorEncoding.Q4_K,
            data = packed
        )

        assertEquals(TensorEncoding.Q4_K, storage.encoding)
        assertEquals(Ownership.BORROWED, storage.ownership)
        assertEquals(144L, storage.physicalBytes)
        assertEquals(1024L, storage.logicalBytes) // 256 elements * 4 bytes
        assertFalse(storage.isMutable)
    }

    @Test
    fun fileBackedCreatesImmutableStorage() {
        val storage = TensorStorageFactory.fileBacked(
            shape = Shape(512, 512),
            logicalType = LogicalDType.FLOAT16,
            encoding = TensorEncoding.Dense(2),
            path = "/model.bin",
            fileOffset = 0,
            sizeInBytes = 512L * 512 * 2
        )

        assertTrue(storage.isFileBacked)
        assertFalse(storage.isMutable)
        assertEquals(Placement.MMAP_WEIGHTS, storage.placement)
        assertEquals(MemoryDomain.MMAP_FILE, storage.placement.domain)
    }

    @Test
    fun fromTensorDataBridgesFloatTensorData() {
        val source = DenseFloatArrayTensorData(Shape(3), floatArrayOf(1f, 2f, 3f))

        val storage = TensorStorageFactory.fromTensorData(source)

        assertEquals(LogicalDType.FLOAT32, storage.logicalType)
        assertEquals(TensorEncoding.Dense(4), storage.encoding)
        assertEquals(3L, storage.elementCount)
        assertEquals(12L, storage.physicalBytes)
    }

    @Test
    fun fromTensorDataBridgesIntTensorData() {
        val storage = TensorStorageFactory.fromTensorData(
            DenseIntArrayTensorData(Shape(2), intArrayOf(10, 20))
        )

        assertEquals(LogicalDType.INT32, storage.logicalType)
        assertEquals(TensorEncoding.Dense(4), storage.encoding)
    }

    @Test
    fun fromTensorDataBridgesQ4KTensorData() {
        val raw = ByteArray(144) // 1 block of Q4_K

        val storage = TensorStorageFactory.fromTensorData(
            Q4_KBlockTensorData.fromRawBytes(Shape(256), raw)
        )

        assertEquals(LogicalDType.FLOAT32, storage.logicalType)
        assertEquals(TensorEncoding.Q4_K, storage.encoding)
        assertEquals(Ownership.BORROWED, storage.ownership)
        assertEquals(144L, storage.physicalBytes)
    }

    @Test
    fun fromTensorDataBridgesQ80TensorData() {
        val raw = ByteArray(34) // 1 block of Q8_0

        val storage = TensorStorageFactory.fromTensorData(
            Q8_0BlockTensorData.fromRawBytes(Shape(32), raw)
        )

        assertEquals(LogicalDType.FLOAT32, storage.logicalType)
        assertEquals(TensorEncoding.Q8_0, storage.encoding)
        assertEquals(Ownership.BORROWED, storage.ownership)
        assertEquals(34L, storage.physicalBytes)
    }

    @Test
    fun memoryReportFromFactory() {
        val storage = TensorStorageFactory.fromRawBytes(
            shape = Shape(256),
            logicalType = LogicalDType.FLOAT32,
            encoding = TensorEncoding.Q4_K,
            data = ByteArray(144)
        )

        val report = storage.memoryReport()

        assertEquals(1024L, report.logicalBytes)
        assertEquals(144L, report.physicalBytes)
        assertTrue(report.compressionRatio > 7.0) // 1024 / 144 ≈ 7.1x
        assertEquals(Ownership.BORROWED, report.ownership)
        assertFalse(report.isFileBacked)
    }
}
/**
 * Tests for [CompressedKvAttention] — the bridge between KvCacheStore and SDPA.
 */
class CompressedKvAttentionTest {

    /** Builds a bridge over a [DefaultKvCacheStore] with small, overridable dimensions. */
    private fun newBridge(
        numLayers: Int = 1,
        numHeads: Int = 2,
        headDim: Int = 4,
        maxSeqLen: Int = 8,
        keyEncoding: TensorEncoding = TensorEncoding.Dense(4),
        valueEncoding: TensorEncoding = TensorEncoding.Dense(4),
        strategy: CompressedKvAttention.DequantStrategy = CompressedKvAttention.DequantStrategy.FULL_TILE
    ): CompressedKvAttention = CompressedKvAttention(
        DefaultKvCacheStore(
            KvCacheConfig(
                numLayers = numLayers,
                numHeads = numHeads,
                headDim = headDim,
                maxSeqLen = maxSeqLen,
                keyEncoding = keyEncoding,
                valueEncoding = valueEncoding
            )
        ),
        strategy
    )

    @Test
    fun storeAndLoadRoundTrip() {
        val bridge = newBridge()
        val keyIn = FloatArray(2 * 4) { it.toFloat() }
        val valueIn = FloatArray(2 * 4) { (it + 100).toFloat() }

        bridge.storeKeyValue(0, keyIn, valueIn)

        val keysOut = bridge.loadKeysForAttention(0)
        val valuesOut = bridge.loadValuesForAttention(0)

        // [numHeads=2, seqLen=1, headDim=4]
        assertEquals(2 * 1 * 4, keysOut.size)
        assertEquals(0f, keysOut[0])
        assertEquals(7f, keysOut[7])

        assertEquals(100f, valuesOut[0])
        assertEquals(107f, valuesOut[7])
    }

    @Test
    fun loadWithSubRange() {
        val bridge = newBridge(numHeads = 1, headDim = 2)

        bridge.storeKeyValue(0, floatArrayOf(1f, 2f), floatArrayOf(10f, 20f))
        bridge.storeKeyValue(0, floatArrayOf(3f, 4f), floatArrayOf(30f, 40f))
        bridge.storeKeyValue(0, floatArrayOf(5f, 6f), floatArrayOf(50f, 60f))

        // Restrict the read to position 1 only.
        val keys = bridge.loadKeysForAttention(0, startPos = 1, endPos = 2)
        assertEquals(2, keys.size)
        assertEquals(3f, keys[0])
        assertEquals(4f, keys[1])
    }

    @Test
    fun rawStorageReturnsTensorStorage() {
        val bridge = newBridge()
        bridge.storeKeyValue(0, FloatArray(8), FloatArray(8))

        val keyStorage = bridge.loadKeyStorageRaw(0)
        assertEquals(LogicalDType.FLOAT32, keyStorage.logicalType)
        assertEquals(Ownership.OWNED, keyStorage.ownership)

        val valueStorage = bridge.loadValueStorageRaw(0)
        assertEquals(LogicalDType.FLOAT32, valueStorage.logicalType)
    }

    @Test
    fun isCompressedDetectsEncoding() {
        val dense = newBridge()
        assertFalse(dense.isKeyCompressed)
        assertFalse(dense.isValueCompressed)

        val compressed = newBridge(
            keyEncoding = TensorEncoding.Q8_0,
            valueEncoding = TensorEncoding.Q4_K
        )
        assertTrue(compressed.isKeyCompressed)
        assertTrue(compressed.isValueCompressed)
    }

    @Test
    fun asymmetricCompression() {
        // Keys compressed, values dense — each side is detected independently.
        val bridge = newBridge(
            keyEncoding = TensorEncoding.Q8_0,
            valueEncoding = TensorEncoding.Dense(4)
        )
        assertTrue(bridge.isKeyCompressed)
        assertFalse(bridge.isValueCompressed)
    }

    @Test
    fun rawStorageStrategyFallsBackToFloat() {
        val bridge = newBridge(
            strategy = CompressedKvAttention.DequantStrategy.RAW_STORAGE
        )
        bridge.storeKeyValue(0, FloatArray(8) { it.toFloat() }, FloatArray(8))

        // RAW_STORAGE still yields floats in the default implementation.
        val keys = bridge.loadKeysForAttention(0)
        assertEquals(8, keys.size)
        assertEquals(0f, keys[0])
    }
}
class ExplicitCopyApiTest {

    private val factory = DenseTensorDataFactory()

    // --- wrapFloatArray (zero-copy) ---

    @Test
    fun wrapFloatArraySharesBuffer() {
        val source = floatArrayOf(1f, 2f, 3f, 4f)

        val wrapped = factory.wrapFloatArray(Shape(4), FP32::class, source)

        assertTrue(wrapped is FloatArrayTensorData<*>)
        // wrap* is zero-copy: the tensor data aliases the caller's array.
        assertSame(source, (wrapped as FloatArrayTensorData<*>).buffer)
    }

    @Test
    fun fromFloatArrayCopiesBuffer() {
        val source = floatArrayOf(1f, 2f, 3f, 4f)

        val copied = factory.fromFloatArray(Shape(4), FP32::class, source)

        assertTrue(copied is FloatArrayTensorData<*>)
        val backing = (copied as FloatArrayTensorData<*>).buffer
        // from* copies: a different array with identical contents.
        assertNotSame(source, backing)
        assertEquals(source.toList(), backing.toList())
    }

    @Test
    fun wrapFloatArrayMutationsVisibleThroughTensorData() {
        val source = floatArrayOf(10f, 20f, 30f)
        val wrapped = factory.wrapFloatArray(Shape(3), FP32::class, source)

        // Writing through the original array is observable via the wrapper.
        source[0] = 99f
        assertEquals(99f, wrapped[0])
    }

    // --- wrapIntArray (zero-copy) ---

    @Test
    fun wrapIntArraySharesBuffer() {
        val source = intArrayOf(10, 20, 30)

        val wrapped = factory.wrapIntArray(Shape(3), Int32::class, source)

        val backing = (wrapped as sk.ainet.lang.tensor.data.IntArrayTensorData<*>).buffer
        assertSame(source, backing)
    }

    @Test
    fun fromIntArrayCopiesBuffer() {
        val source = intArrayOf(10, 20, 30)

        val copied = factory.fromIntArray(Shape(3), Int32::class, source)

        val backing = (copied as sk.ainet.lang.tensor.data.IntArrayTensorData<*>).buffer
        assertNotSame(source, backing)
    }

    // --- TensorStorage bridge with borrowed vs owned ---

    @Test
    fun tensorStorageFromBorrowedRawBytes() {
        val raw = ByteArray(144) // one Q4_K block

        val storage = TensorStorageFactory.fromRawBytes(
            shape = Shape(256),
            logicalType = LogicalDType.FLOAT32,
            encoding = TensorEncoding.Q4_K,
            data = raw
        )

        assertEquals(Ownership.BORROWED, storage.ownership)
    }

    @Test
    fun tensorStorageFromOwnedRawBytes() {
        val raw = ByteArray(144)

        val storage = TensorStorageFactory.fromRawBytesOwned(
            shape = Shape(256),
            logicalType = LogicalDType.FLOAT32,
            encoding = TensorEncoding.Q4_K,
            data = raw
        )

        assertEquals(Ownership.OWNED, storage.ownership)
    }
}
/**
 * Tests for KvCacheStore factory methods, TurboQuantPresets.forModel(),
 * and KvCacheAnnotationResolver.
 *
 * NOTE(review): the generic type arguments on assertIs/assertFailsWith were
 * stripped when this patch was transferred (angle-bracketed text lost), which
 * left calls like `assertIs(cache)` that do not compile. They are restored
 * below with the weakest type the surrounding code demonstrates:
 * - `TensorEncoding.TurboQuantPolar` is grounded by the adjacent casts;
 * - `IllegalArgumentException` by the *Throws test names (require-style checks);
 * - the store itself is asserted as `KvCacheStore` — tighten to the concrete
 *   store class once confirmed against the factory implementations.
 */
class KvCacheFactoryTest {

    // --- KvCacheStore.dense() ---

    @Test
    fun denseFactoryCreatesDenseStore() {
        val cache = KvCacheStore.dense(numLayers = 2, numHeads = 4, headDim = 64, maxSeqLen = 128)
        assertIs<KvCacheStore>(cache)
        assertEquals(2, cache.numLayers)
        assertEquals(4, cache.numHeads)
        assertEquals(64, cache.headDim)
        assertEquals(128, cache.maxSeqLen)
    }

    // --- KvCacheStore.turboQuant(preset) ---

    @Test
    fun turboQuantPresetBalanced() {
        val cache = KvCacheStore.turboQuant("balanced", 2, 4, 64, 128)
        assertIs<KvCacheStore>(cache)
        assertEquals(2, cache.numLayers)
        // "balanced" uses polar TurboQuant for both sides; the cast below
        // grounds the type argument.
        assertIs<TensorEncoding.TurboQuantPolar>(cache.keyEncoding)
        assertIs<TensorEncoding.TurboQuantPolar>(cache.valueEncoding)
        assertEquals(4, (cache.keyEncoding as TensorEncoding.TurboQuantPolar).bitsPerElement)
    }

    @Test
    fun turboQuantPresetSafeLowbit() {
        val cache = KvCacheStore.turboQuant("safe-lowbit", 2, 4, 64, 128)
        assertIs<KvCacheStore>(cache)
        assertEquals(TensorEncoding.Q8_0, cache.keyEncoding)
        assertIs<TensorEncoding.TurboQuantPolar>(cache.valueEncoding)
    }

    @Test
    fun turboQuantPresetExperimentalMax() {
        val cache = KvCacheStore.turboQuant("experimental-max", 2, 4, 64, 128)
        assertIs<KvCacheStore>(cache)
        assertEquals(3, (cache.keyEncoding as TensorEncoding.TurboQuantPolar).bitsPerElement)
    }

    @Test
    fun turboQuantUnknownPresetThrows() {
        // Unknown preset names are an argument error.
        assertFailsWith<IllegalArgumentException> {
            KvCacheStore.turboQuant("nonexistent", 2, 4, 64, 128)
        }
    }

    // --- KvCacheStore.turboQuant(custom) ---

    @Test
    fun turboQuantCustomBits() {
        val cache = KvCacheStore.turboQuant(
            numLayers = 2, numHeads = 4, headDim = 64, maxSeqLen = 128,
            keyBits = 8, valueBits = 3
        )
        assertIs<KvCacheStore>(cache)
        assertEquals(8, (cache.keyEncoding as TensorEncoding.TurboQuantPolar).bitsPerElement)
        assertEquals(3, (cache.valueEncoding as TensorEncoding.TurboQuantPolar).bitsPerElement)
    }

    // --- KvCacheStore.fromPreset() ---

    @Test
    fun fromPresetCreatesCorrectCache() {
        val preset = TurboQuantPresets.balanced(2, 4, 64, 128)
        val cache = KvCacheStore.fromPreset(preset)
        assertIs<KvCacheStore>(cache)
        assertEquals(2, cache.numLayers)
    }

    // --- TurboQuantPresets.forModel() ---

    @Test
    fun forModelBalanced() {
        val preset = TurboQuantPresets.forModel("balanced", 32, 32, 128, 4096)
        assertEquals("balanced", preset.name)
        assertEquals(32, preset.cacheConfig.numLayers)
        assertEquals(4096, preset.cacheConfig.maxSeqLen)
    }

    @Test
    fun forModelUnknownThrows() {
        assertFailsWith<IllegalArgumentException> {
            TurboQuantPresets.forModel("invalid", 2, 4, 64, 128)
        }
    }

    // --- KvCacheAnnotationResolver ---

    @Test
    fun resolvePresetString() {
        val cache = KvCacheAnnotationResolver.resolve("balanced", 2, 4, 64, 128)
        assertIs<KvCacheStore>(cache)
    }

    @Test
    fun resolveDensePreset() {
        val cache = KvCacheAnnotationResolver.resolve("dense", 2, 4, 64, 128)
        assertIs<KvCacheStore>(cache)
    }

    @Test
    fun resolveNonePreset() {
        val cache = KvCacheAnnotationResolver.resolve("none", 2, 4, 64, 128)
        assertIs<KvCacheStore>(cache)
    }

    // --- End-to-end: factory → append → read ---

    @Test
    fun factoryCreatedCacheWorksEndToEnd() {
        val cache = KvCacheStore.turboQuant("balanced", 1, 2, 64, 16)
        val bridge = CompressedKvAttention(cache)

        val key = FloatArray(2 * 64) { it.toFloat() / 128f }
        val value = FloatArray(2 * 64) { -it.toFloat() / 128f }

        bridge.storeKeyValue(0, key, value)
        assertEquals(1, cache.currentSeqLen)

        val readK = bridge.loadKeysForAttention(0)
        assertEquals(2 * 1 * 64, readK.size)

        val report = cache.memoryReport()
        assertTrue(report.compressionRatio > 1.0,
            "TurboQuant should compress: ratio=${report.compressionRatio}")
    }
}
+ } + + // --- TurboQuantPresets.forModel() --- + + @Test + fun forModelBalanced() { + val preset = TurboQuantPresets.forModel("balanced", 32, 32, 128, 4096) + assertEquals("balanced", preset.name) + assertEquals(32, preset.cacheConfig.numLayers) + assertEquals(4096, preset.cacheConfig.maxSeqLen) + } + + @Test + fun forModelUnknownThrows() { + assertFailsWith { + TurboQuantPresets.forModel("invalid", 2, 4, 64, 128) + } + } + + // --- KvCacheAnnotationResolver --- + + @Test + fun resolvePresetString() { + val cache = KvCacheAnnotationResolver.resolve("balanced", 2, 4, 64, 128) + assertIs(cache) + } + + @Test + fun resolveDensePreset() { + val cache = KvCacheAnnotationResolver.resolve("dense", 2, 4, 64, 128) + assertIs(cache) + } + + @Test + fun resolveNonePreset() { + val cache = KvCacheAnnotationResolver.resolve("none", 2, 4, 64, 128) + assertIs(cache) + } + + // --- End-to-end: factory → append → read --- + + @Test + fun factoryCreatedCacheWorksEndToEnd() { + val cache = KvCacheStore.turboQuant("balanced", 1, 2, 64, 16) + val bridge = CompressedKvAttention(cache) + + val key = FloatArray(2 * 64) { it.toFloat() / 128f } + val value = FloatArray(2 * 64) { -it.toFloat() / 128f } + + bridge.storeKeyValue(0, key, value) + assertEquals(1, cache.currentSeqLen) + + val readK = bridge.loadKeysForAttention(0) + assertEquals(2 * 1 * 64, readK.size) + + val report = cache.memoryReport() + assertTrue(report.compressionRatio > 1.0, + "TurboQuant should compress: ratio=${report.compressionRatio}") + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/KvCacheStoreTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/KvCacheStoreTest.kt new file mode 100644 index 00000000..24f84928 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/KvCacheStoreTest.kt @@ -0,0 +1,238 @@ +package sk.ainet.lang.tensor.storage + +import kotlin.test.Test +import 
kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +/** + * Tests for [KvCacheStore] contract and [DefaultKvCacheStore] implementation. + */ +class KvCacheStoreTest { + + private fun createStore( + numLayers: Int = 2, + numHeads: Int = 4, + headDim: Int = 8, + maxSeqLen: Int = 16 + ): DefaultKvCacheStore = DefaultKvCacheStore( + KvCacheConfig(numLayers, numHeads, headDim, maxSeqLen) + ) + + // --- Append and read --- + + @Test + fun appendAndReadSingleToken() { + val store = createStore(numLayers = 1, numHeads = 2, headDim = 4, maxSeqLen = 8) + val key = FloatArray(2 * 4) { it.toFloat() } // [0..7] + val value = FloatArray(2 * 4) { (it + 10).toFloat() } // [10..17] + + store.appendToken(0, key, value) + assertEquals(1, store.currentSeqLen) + + val readK = store.readKeys(0) + val readV = store.readValues(0) + + // Shape: [numHeads=2, seqLen=1, headDim=4] + assertEquals(2 * 1 * 4, readK.size) + assertEquals(2 * 1 * 4, readV.size) + + // Head 0: [0, 1, 2, 3] + assertEquals(0f, readK[0]) + assertEquals(1f, readK[1]) + assertEquals(2f, readK[2]) + assertEquals(3f, readK[3]) + + // Head 1: [4, 5, 6, 7] + assertEquals(4f, readK[4]) + assertEquals(5f, readK[5]) + + // Values: head 0 starts at 10 + assertEquals(10f, readV[0]) + assertEquals(13f, readV[3]) + } + + @Test + fun appendMultipleTokens() { + val store = createStore(numLayers = 1, numHeads = 1, headDim = 2, maxSeqLen = 4) + + // Token 0 + store.appendToken(0, floatArrayOf(1f, 2f), floatArrayOf(10f, 20f)) + // Token 1 + store.appendToken(0, floatArrayOf(3f, 4f), floatArrayOf(30f, 40f)) + + assertEquals(2, store.currentSeqLen) + + val keys = store.readKeys(0) + // [numHeads=1, seqLen=2, headDim=2] = [1, 2, 3, 4] + assertEquals(4, keys.size) + assertEquals(1f, keys[0]) + assertEquals(2f, keys[1]) + assertEquals(3f, keys[2]) + assertEquals(4f, keys[3]) + } + + @Test + fun appendMultipleLayers() { + val store = createStore(numLayers = 2, numHeads = 1, headDim = 2, maxSeqLen = 
4) + + // Layer 0 then Layer 1 for token 0 + store.appendToken(0, floatArrayOf(1f, 2f), floatArrayOf(10f, 20f)) + store.appendToken(1, floatArrayOf(5f, 6f), floatArrayOf(50f, 60f)) + + assertEquals(1, store.currentSeqLen) + + // Layer 0 keys + val k0 = store.readKeys(0) + assertEquals(1f, k0[0]) + assertEquals(2f, k0[1]) + + // Layer 1 keys + val k1 = store.readKeys(1) + assertEquals(5f, k1[0]) + assertEquals(6f, k1[1]) + } + + // --- Range reads --- + + @Test + fun readSubRange() { + val store = createStore(numLayers = 1, numHeads = 1, headDim = 2, maxSeqLen = 8) + + // Append 4 tokens + for (i in 0 until 4) { + store.appendToken(0, floatArrayOf(i.toFloat(), (i + 10).toFloat()), floatArrayOf(0f, 0f)) + } + + // Read only positions 1..3 + val keys = store.readKeys(0, startPos = 1, endPos = 3) + assertEquals(4, keys.size) // [1, numHeads=1] * 2 positions * headDim=2 + assertEquals(1f, keys[0]) + assertEquals(11f, keys[1]) + assertEquals(2f, keys[2]) + assertEquals(12f, keys[3]) + } + + // --- TensorStorage output --- + + @Test + fun readKeyStorageReturnsTensorStorage() { + val store = createStore(numLayers = 1, numHeads = 2, headDim = 4, maxSeqLen = 8) + store.appendToken(0, FloatArray(8) { it.toFloat() }, FloatArray(8)) + + val storage = store.readKeyStorage(0) + assertEquals(LogicalDType.FLOAT32, storage.logicalType) + assertEquals(Ownership.OWNED, storage.ownership) + assertTrue(storage.isMutable) + } + + // --- Eviction --- + + @Test + fun evictTruncatesCache() { + val store = createStore(numLayers = 1, numHeads = 1, headDim = 2, maxSeqLen = 8) + + for (i in 0 until 4) { + store.appendToken(0, floatArrayOf(i.toFloat(), 0f), floatArrayOf(0f, 0f)) + } + assertEquals(4, store.currentSeqLen) + + store.evict(fromPos = 2) + assertEquals(2, store.currentSeqLen) + + val keys = store.readKeys(0) + assertEquals(4, keys.size) // 2 positions * headDim=2 + assertEquals(0f, keys[0]) + assertEquals(1f, keys[2]) + } + + @Test + fun clearResetsEverything() { + val store = 
createStore(numLayers = 1, numHeads = 1, headDim = 2, maxSeqLen = 4) + store.appendToken(0, floatArrayOf(1f, 2f), floatArrayOf(3f, 4f)) + assertEquals(1, store.currentSeqLen) + + store.clear() + assertEquals(0, store.currentSeqLen) + } + + // --- Capacity --- + + @Test + fun appendBeyondCapacityThrows() { + val store = createStore(numLayers = 1, numHeads = 1, headDim = 2, maxSeqLen = 2) + store.appendToken(0, floatArrayOf(1f, 2f), floatArrayOf(3f, 4f)) + store.appendToken(0, floatArrayOf(5f, 6f), floatArrayOf(7f, 8f)) + + assertFailsWith { + store.appendToken(0, floatArrayOf(9f, 10f), floatArrayOf(11f, 12f)) + } + } + + // --- Validation --- + + @Test + fun invalidLayerIndexThrows() { + val store = createStore(numLayers = 2) + assertFailsWith { + store.appendToken(5, FloatArray(store.numHeads * store.headDim), FloatArray(store.numHeads * store.headDim)) + } + } + + @Test + fun wrongKeySizeThrows() { + val store = createStore(numLayers = 1, numHeads = 2, headDim = 4) + assertFailsWith { + store.appendToken(0, FloatArray(3), FloatArray(8)) // wrong key size + } + } + + // --- Memory report --- + + @Test + fun memoryReportIsAccurate() { + val store = createStore(numLayers = 2, numHeads = 4, headDim = 8, maxSeqLen = 16) + store.appendToken(0, FloatArray(32), FloatArray(32)) + store.appendToken(1, FloatArray(32), FloatArray(32)) + + val report = store.memoryReport() + assertEquals(2, report.numLayers) + assertEquals(4, report.numHeads) + assertEquals(8, report.headDim) + assertEquals(16, report.maxSeqLen) + assertEquals(1, report.currentSeqLen) + assertEquals(TensorEncoding.Dense(4), report.keyEncoding) + assertTrue(report.totalPhysicalBytes > 0) + assertTrue(report.utilizationRatio > 0.0) + assertTrue(report.utilizationRatio < 1.0) + } + + // --- Config validation --- + + @Test + fun invalidConfigThrows() { + assertFailsWith { + KvCacheConfig(numLayers = 0, numHeads = 4, headDim = 8, maxSeqLen = 16) + } + assertFailsWith { + KvCacheConfig(numLayers = 2, numHeads = 0, 
/**
 * Tests for MemoryPlanner placement resolution and placement suggestions.
 *
 * NOTE(review): the exception type argument on assertFailsWith was stripped
 * when this patch was transferred. A REQUIRED placement on an unavailable
 * device is treated as a state violation (IllegalStateException) below —
 * confirm against MemoryPlanner.resolve().
 */
class MemoryPlannerTest {

    @Test
    fun cpuPlacementResolvesDirectly() {
        val planner = MemoryPlanner(availableDevices = setOf(DeviceKind.CPU))
        val result = planner.resolve(Placement.CPU_HEAP)
        assertEquals(DeviceKind.CPU, result.actual.device)
        assertFalse(result.usedFallback)
    }

    @Test
    fun gpuPreferredFallsToCpuWhenNoGpu() {
        val planner = MemoryPlanner(availableDevices = setOf(DeviceKind.CPU))
        val result = planner.resolve(Placement.GPU_PREFERRED)
        assertEquals(DeviceKind.CPU, result.actual.device)
        assertEquals(MemoryDomain.HOST_HEAP, result.actual.domain) // DEVICE_LOCAL falls to HOST_HEAP
        assertTrue(result.usedFallback)
    }

    @Test
    fun gpuRequiredThrowsWhenNoGpu() {
        val planner = MemoryPlanner(availableDevices = setOf(DeviceKind.CPU))
        val required = Placement(
            device = DeviceKind.GPU,
            domain = MemoryDomain.DEVICE_LOCAL,
            requirement = Requirement.REQUIRED
        )
        // REQUIRED placement cannot fall back, so resolution must fail.
        assertFailsWith<IllegalStateException> {
            planner.resolve(required)
        }
    }

    @Test
    fun gpuResolvesDirectlyWhenAvailable() {
        val planner = MemoryPlanner(availableDevices = setOf(DeviceKind.CPU, DeviceKind.GPU))
        val result = planner.resolve(Placement.GPU_PREFERRED)
        assertEquals(DeviceKind.GPU, result.actual.device)
        assertFalse(result.usedFallback)
    }

    @Test
    fun autoPicksBestDevice() {
        val planner = MemoryPlanner(availableDevices = setOf(DeviceKind.CPU, DeviceKind.GPU))
        val result = planner.resolve(Placement(device = DeviceKind.AUTO))
        assertEquals(DeviceKind.GPU, result.actual.device) // GPU preferred over CPU
        assertFalse(result.usedFallback)
    }

    @Test
    fun suggestWeightPlacementFileBacked() {
        val planner = MemoryPlanner()
        val p = planner.suggestWeightPlacement(isFileBacked = true)
        assertEquals(MemoryDomain.MMAP_FILE, p.domain)
        assertEquals(Residency.PERSISTENT, p.residency)
    }

    @Test
    fun suggestActivationPlacement() {
        val planner = MemoryPlanner()
        val p = planner.suggestActivationPlacement()
        assertEquals(MemoryDomain.HOST_HEAP, p.domain)
        assertEquals(Residency.TRANSIENT, p.residency)
    }
}
class MemoryTrackerTest {

    @Test
    fun trackAndReport() {
        val tracker = MemoryTracker()

        // One owned dense tensor and one borrowed block-quantized tensor.
        val dense = TensorStorage(
            shape = Shape(100),
            logicalType = LogicalDType.FLOAT32,
            encoding = TensorEncoding.Dense(4),
            buffer = BufferHandle.Owned(ByteArray(400))
        )
        val quantized = TensorStorage(
            shape = Shape(256),
            logicalType = LogicalDType.FLOAT32,
            encoding = TensorEncoding.Q4_K,
            buffer = BufferHandle.Borrowed(ByteArray(144))
        )

        tracker.record("weight1", dense)
        tracker.record("weight2_q4k", quantized)

        val report = tracker.report()
        assertEquals(2, report.tensorCount)
        assertEquals(1, report.ownedCount)
        assertEquals(1, report.borrowedCount)
        assertEquals(400L + 1024L, report.totalLogicalBytes)  // 100*4 + 256*4
        assertEquals(400L + 144L, report.totalPhysicalBytes)
    }

    @Test
    fun trackCopies() {
        val tracker = MemoryTracker()
        tracker.recordCopy("tensor_a", 1024)
        tracker.recordCopy("tensor_b", 2048)

        val report = tracker.report()
        assertEquals(2L, report.copyCount)
        assertEquals(3072L, report.copyBytes)
    }

    @Test
    fun clearResetsState() {
        val tracker = MemoryTracker()
        val storage = TensorStorage(
            shape = Shape(10),
            logicalType = LogicalDType.FLOAT32,
            encoding = TensorEncoding.Dense(4),
            buffer = BufferHandle.Owned(ByteArray(40))
        )
        tracker.record("x", storage)
        tracker.recordCopy("x", 40)

        tracker.clear()

        val report = tracker.report()
        assertEquals(0, report.tensorCount)
        assertEquals(0L, report.copyCount)
    }

    @Test
    fun fileBackedTracking() {
        val tracker = MemoryTracker()
        val mmapped = TensorStorage(
            shape = Shape(1000),
            logicalType = LogicalDType.FLOAT16,
            encoding = TensorEncoding.Dense(2),
            buffer = BufferHandle.FileBacked("/model.bin", 0, 2000),
            placement = Placement.MMAP_WEIGHTS
        )
        tracker.record("mmap_weight", mmapped)

        val report = tracker.report()
        assertEquals(1, report.fileBackedCount)
        assertEquals(2000L, report.fileBackedBytes)
    }
}
class NonContiguousStorageTest {

    @Test
    fun defaultStorage_stridesNull_isContiguousTrue() {
        // When strides are omitted, storage is contiguous by default.
        val storage = TensorStorage(
            shape = Shape(4, 4),
            logicalType = LogicalDType.FLOAT32,
            encoding = TensorEncoding.Dense(4),
            buffer = BufferHandle.Owned(ByteArray(64))
        )
        assertNull(storage.strides)
        assertTrue(storage.isContiguous)
    }

    @Test
    fun nonContiguous_stridesPreserved() {
        val rowMajorStrides = longArrayOf(768, 1)
        val storage = TensorStorage(
            shape = Shape(1024, 768),
            logicalType = LogicalDType.FLOAT32,
            encoding = TensorEncoding.Dense(4),
            buffer = BufferHandle.Owned(ByteArray(1024 * 768 * 4)),
            strides = rowMajorStrides,
            isContiguous = false
        )
        assertEquals(768L, storage.strides!![0])
        assertEquals(1L, storage.strides!![1])
        assertFalse(storage.isContiguous)
    }

    @Test
    fun equalityIncludesStrides() {
        val shared = BufferHandle.Owned(ByteArray(64))
        val rowMajor = TensorStorage(
            shape = Shape(4, 4),
            logicalType = LogicalDType.FLOAT32,
            encoding = TensorEncoding.Dense(4),
            buffer = shared,
            strides = longArrayOf(4, 1)
        )
        val transposed = TensorStorage(
            shape = Shape(4, 4),
            logicalType = LogicalDType.FLOAT32,
            encoding = TensorEncoding.Dense(4),
            buffer = shared,
            strides = longArrayOf(1, 4) // transposed strides
        )
        assertNotEquals(rowMajor, transposed)
    }

    @Test
    fun equalityNullStridesMatch() {
        val shared = BufferHandle.Owned(ByteArray(16))
        val a = TensorStorage(Shape(4), LogicalDType.FLOAT32, TensorEncoding.Dense(4), shared)
        val b = TensorStorage(Shape(4), LogicalDType.FLOAT32, TensorEncoding.Dense(4), shared)
        assertEquals(a, b)
    }

    @Test
    fun hashCodeDiffersWithDifferentStrides() {
        val shared = BufferHandle.Owned(ByteArray(64))
        val a = TensorStorage(Shape(4, 4), LogicalDType.FLOAT32, TensorEncoding.Dense(4), shared, strides = longArrayOf(4, 1))
        val b = TensorStorage(Shape(4, 4), LogicalDType.FLOAT32, TensorEncoding.Dense(4), shared, strides = longArrayOf(1, 4))
        // Not guaranteed by contract but highly likely for different strides
        assertNotEquals(a.hashCode(), b.hashCode())
    }

    @Test
    fun memoryReport_nonContiguous_reportsCorrectBytes() {
        val storage = TensorStorage(
            shape = Shape(8, 8),
            logicalType = LogicalDType.FLOAT32,
            encoding = TensorEncoding.Dense(4),
            buffer = BufferHandle.Owned(ByteArray(256)),
            strides = longArrayOf(8, 1),
            isContiguous = false
        )
        val report = storage.memoryReport()
        // Physical/logical bytes computed from shape and encoding, not strides
        assertEquals(256L, report.logicalBytes)  // 64 * 4
        assertEquals(256L, report.physicalBytes) // Dense: 64 * 4
    }
}
class PackedBlockStorageTest {

    @Test
    fun q4kImplementsPackedBlockStorage() {
        val raw = ByteArray(144) // 1 Q4_K block
        val tensorData = Q4_KBlockTensorData.fromRawBytes(Shape(256), raw)

        assertTrue(tensorData is PackedBlockStorage)
        val packed = tensorData as PackedBlockStorage
        assertEquals(TensorEncoding.Q4_K, packed.encoding)
        assertEquals(256, packed.blockSize)
        assertEquals(1, packed.blockCount)
        assertEquals(144L, packed.physicalBytes)
        assertEquals(256L, packed.elementCount)
    }

    @Test
    fun q80ImplementsPackedBlockStorage() {
        val raw = ByteArray(34) // 1 Q8_0 block
        val tensorData = Q8_0BlockTensorData.fromRawBytes(Shape(32), raw)

        assertTrue(tensorData is PackedBlockStorage)
        val packed = tensorData as PackedBlockStorage
        assertEquals(TensorEncoding.Q8_0, packed.encoding)
        assertEquals(32, packed.blockSize)
        assertEquals(1, packed.blockCount)
        assertEquals(34L, packed.physicalBytes)
        assertEquals(32L, packed.elementCount)
    }

    /** Builds a single Q8_0 block: f16 scale = 1.0 followed by codes 1..32. */
    private fun unitScaleQ80Block(): ByteArray {
        val raw = ByteArray(34)
        // f16 for 1.0: sign=0, exp=15, mant=0 → 0x3C00, stored little-endian.
        raw[0] = 0x00.toByte()
        raw[1] = 0x3C.toByte()
        for (i in 0 until 32) {
            raw[2 + i] = (i + 1).toByte()
        }
        return raw
    }

    @Test
    fun q80DequantizeBlockProducesCorrectOutput() {
        val tensorData = Q8_0BlockTensorData.fromRawBytes(Shape(32), unitScaleQ80Block())
        val packed = tensorData as PackedBlockStorage

        val output = FloatArray(32)
        packed.dequantizeBlock(0, output)

        // output[i] = code[i] * scale = (i+1) * 1.0
        for (i in 0 until 32) {
            assertEquals((i + 1).toFloat(), output[i], "Element $i")
        }
    }

    @Test
    fun q80ToFloatArrayDequantizesAll() {
        val tensorData = Q8_0BlockTensorData.fromRawBytes(Shape(32), unitScaleQ80Block())
        val packed = tensorData as PackedBlockStorage

        val floats = packed.toFloatArray()

        assertEquals(32, floats.size)
        assertEquals(1.0f, floats[0])
        assertEquals(32.0f, floats[31])
    }

    @Test
    fun packedBlockStorageToTensorStorage() {
        val raw = ByteArray(144)
        val tensorData = Q4_KBlockTensorData.fromRawBytes(Shape(256), raw)
        val packed = tensorData as PackedBlockStorage

        val storage = packed.toTensorStorage()

        assertEquals(LogicalDType.FLOAT32, storage.logicalType)
        assertEquals(TensorEncoding.Q4_K, storage.encoding)
        assertEquals(Ownership.BORROWED, storage.ownership)
        assertEquals(144L, storage.physicalBytes)
        assertEquals(1024L, storage.logicalBytes) // 256 * 4
    }
}
b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/Q4KDequantizationTest.kt @@ -0,0 +1,168 @@ +package sk.ainet.lang.tensor.storage + +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.data.Q4_KBlockTensorData +import sk.ainet.lang.tensor.data.Q4_KTensorData +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +class Q4KDequantizationTest { + + /** + * Build a 144-byte Q4_K block with controlled values. + * + * Layout: + * - bytes [0..1]: f16 d (main scale) + * - bytes [2..3]: f16 dMin (minimum scale) + * - bytes [4..15]: packed 12-bit scale/min indices (12 bytes) + * - bytes [16..143]: 4-bit codes (128 bytes, 2 codes per byte) + */ + private fun buildQ4KBlock( + d: Float = 1.0f, + dMin: Float = 0.0f, + codeValue: Int = 0 + ): ByteArray { + val block = ByteArray(Q4_KTensorData.BYTES_PER_BLOCK) // 144 + + // d as f16 little-endian + val dBits = floatToHalf(d) + block[0] = (dBits and 0xFF).toByte() + block[1] = ((dBits shr 8) and 0xFF).toByte() + + // dMin as f16 little-endian + val dMinBits = floatToHalf(dMin) + block[2] = (dMinBits and 0xFF).toByte() + block[3] = ((dMinBits shr 8) and 0xFF).toByte() + + // Scale/min indices: all 63 for scale, all 0 for min + // Each sub-block uses 12 bits: 6 for scaleIdx + 6 for minIdx + // 8 sub-blocks * 12 bits = 96 bits = 12 bytes + // scaleIdx=63 (0x3F) in the low 6 bits, minIdx=0 (0x00) in the high 6 bits → each 12-bit group = 0x03F + for (i in 0 until 12) { + // Pack all scale indices as 63 and min indices as 0 + // Bit pattern per sub-block: scaleIdx=111111, minIdx=000000 + // In 12-bit groups (scale index in the low 6 bits): 0b000000_111111 = 0x03F + val bitStart = i * 8 + var byteVal = 0 + for (bit in 0 until 8) { + val globalBit = bitStart + bit + val subBlock = globalBit / 12 + val bitInSubBlock = globalBit % 12 + if (subBlock < 8 && bitInSubBlock < 6) { + // This is a scale index bit — set to 1 (index = 63) + byteVal = byteVal or (1 shl bit) + } + // 
min index bits stay 0 + } + block[4 + i] = byteVal.toByte() + } + + // 4-bit codes: fill all with codeValue (0..15) + val codeByte = ((codeValue and 0x0F) or ((codeValue and 0x0F) shl 4)).toByte() + for (i in 16 until 144) { + block[i] = codeByte + } + + return block + } + + private fun floatToHalf(value: Float): Int { + val bits = value.toRawBits() + val sign = (bits shr 16) and 0x8000 + val exponent = ((bits shr 23) and 0xFF) - 127 + val mantissa = bits and 0x7FFFFF + + return when { + exponent >= 16 -> sign or 0x7C00 // overflow → infinity + exponent >= -14 -> sign or ((exponent + 15) shl 10) or (mantissa shr 13) + else -> sign // underflow → zero + } + } + + @Test + fun dequantizeBlock_uniformCodes_producesExpectedOutput() { + // d=1.0, dMin=0.0, all scale indices=63, all codes=5 + // scale = d * (63/63) = 1.0, min = 0.0 + // output = code * scale + min = 5 * 1.0 + 0.0 = 5.0 + val block = buildQ4KBlock(d = 1.0f, dMin = 0.0f, codeValue = 5) + val td = Q4_KBlockTensorData.fromRawBytes(Shape(256), block) + + val output = FloatArray(256) + td.dequantizeBlock(0, output) + + for (i in 0 until 256) { + assertEquals(5.0f, output[i], "Element $i should be 5.0") + } + } + + @Test + fun getCode_lowAndHighNibble_correct() { + val block = ByteArray(144) + // Put a known byte at code position: byte at offset 16 + // Low nibble = 0xA (10), high nibble = 0x5 (5) + block[16] = 0x5A.toByte() + + val td = Q4_KBlockTensorData.fromRawBytes(Shape(256), block) + // Element 0 → low nibble of byte 16 → 0xA = 10 + assertEquals(10, td.getCode(0, 0)) + // Element 1 → high nibble of byte 16 → 0x5 = 5 + assertEquals(5, td.getCode(0, 1)) + } + + @Test + fun toFloatArray_multiBlock_concatenatesBlocks() { + // 2 blocks = 512 elements + val data = ByteArray(288) // 2 * 144 + // Both blocks: d=1.0, dMin=0.0, all codes=0 + val block1 = buildQ4KBlock(d = 1.0f, dMin = 0.0f, codeValue = 0) + block1.copyInto(data, 0) + block1.copyInto(data, 144) + + val td = 
Q4_KBlockTensorData.fromRawBytes(Shape(512), data) + val floats = (td as PackedBlockStorage).toFloatArray() + + assertEquals(512, floats.size) + } + + @Test + fun dequantizeBlock_outOfBoundsIndex_throws() { + val block = ByteArray(144) + val td = Q4_KBlockTensorData.fromRawBytes(Shape(256), block) + val output = FloatArray(256) + + assertFailsWith { + td.dequantizeBlock(-1, output) + } + assertFailsWith { + td.dequantizeBlock(1, output) // only 1 block (index 0) + } + } + + @Test + fun physicalBytes_matchesExpected() { + val block = ByteArray(144) + val td = Q4_KBlockTensorData.fromRawBytes(Shape(256), block) + val packed = td as PackedBlockStorage + + assertEquals(144L, packed.physicalBytes) + assertEquals(256L, packed.elementCount) + assertEquals(1, packed.blockCount) + assertEquals(256, packed.blockSize) + } + + @Test + fun dequantizeBlock_zeroCodes_producesMinValues() { + // d=1.0, dMin=0.0, all codes=0 → output = 0*scale + min = 0.0 + val block = buildQ4KBlock(d = 1.0f, dMin = 0.0f, codeValue = 0) + val td = Q4_KBlockTensorData.fromRawBytes(Shape(256), block) + + val output = FloatArray(256) + td.dequantizeBlock(0, output) + + for (i in 0 until 256) { + assertEquals(0.0f, output[i], "Element $i should be 0.0 for zero codes") + } + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/StorageToTensorDataTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/StorageToTensorDataTest.kt new file mode 100644 index 00000000..f3d1588c --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/StorageToTensorDataTest.kt @@ -0,0 +1,110 @@ +package sk.ainet.lang.tensor.storage + +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.data.DenseFloatArrayTensorData +import sk.ainet.lang.tensor.data.DenseIntArrayTensorData +import sk.ainet.lang.tensor.data.FloatArrayTensorData +import sk.ainet.lang.tensor.data.IntArrayTensorData +import 
sk.ainet.lang.tensor.data.Q4_KBlockTensorData +import sk.ainet.lang.tensor.data.Q4_KTensorData +import sk.ainet.lang.tensor.data.Q8_0BlockTensorData +import sk.ainet.lang.tensor.data.Q8_0TensorData +import sk.ainet.lang.types.DType +import sk.ainet.lang.types.FP32 +import sk.ainet.lang.types.Int32 +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class StorageToTensorDataTest { + + @Test + fun roundTripDenseFloat32() { + // TensorData → TensorStorage → TensorData + val original = DenseFloatArrayTensorData(Shape(3), floatArrayOf(1f, 2f, 3f)) + val storage = TensorStorageFactory.fromTensorData(original) + val restored = TensorStorageFactory.toTensorData(storage) + + assertTrue(restored is FloatArrayTensorData<*>) + val floats = (restored as FloatArrayTensorData<*>).buffer + assertEquals(3, floats.size) + assertEquals(1f, floats[0]) + assertEquals(2f, floats[1]) + assertEquals(3f, floats[2]) + } + + @Test + fun roundTripDenseInt32() { + val original = DenseIntArrayTensorData(Shape(4), intArrayOf(10, 20, 30, 40)) + val storage = TensorStorageFactory.fromTensorData(original) + val restored = TensorStorageFactory.toTensorData(storage) + + assertTrue(restored is IntArrayTensorData<*>) + val ints = (restored as IntArrayTensorData<*>).buffer + assertEquals(4, ints.size) + assertEquals(10, ints[0]) + assertEquals(40, ints[3]) + } + + @Test + fun roundTripQ4K() { + val rawBytes = ByteArray(144) // 1 Q4_K block + rawBytes[10] = 42 // put something non-zero to verify identity + val original = Q4_KBlockTensorData.fromRawBytes(Shape(256), rawBytes) + val storage = TensorStorageFactory.fromTensorData(original) + + assertEquals(TensorEncoding.Q4_K, storage.encoding) + + val restored = TensorStorageFactory.toTensorData(storage) + assertTrue(restored is Q4_KTensorData) + assertEquals(256, restored.shape.volume) + assertEquals(42, (restored as Q4_KTensorData).packedData[10]) + } + + @Test + fun roundTripQ80() { + // Build a Q8_0 block: 
scale=1.0 (f16 0x3C00) + 32 code bytes + val rawBytes = ByteArray(34) + rawBytes[0] = 0x00 + rawBytes[1] = 0x3C + for (i in 0 until 32) rawBytes[2 + i] = (i + 1).toByte() + + val original = Q8_0BlockTensorData.fromRawBytes(Shape(32), rawBytes) + val storage = TensorStorageFactory.fromTensorData(original) + + assertEquals(TensorEncoding.Q8_0, storage.encoding) + + val restored = TensorStorageFactory.toTensorData(storage) + assertTrue(restored is Q8_0TensorData) + val q80 = restored as Q8_0TensorData + assertEquals(32, q80.shape.volume) + // Verify codes are intact + assertEquals(1.toByte(), q80.getCode(0, 0)) + assertEquals(32.toByte(), q80.getCode(0, 31)) + } + + @Test + fun toTensorDataFromBorrowedFloat32() { + // Create storage from raw bytes directly + val floatBytes = ByteArray(12) // 3 floats + // 1.0f = 0x3F800000 little-endian + floatBytes[0] = 0x00; floatBytes[1] = 0x00; floatBytes[2] = 0x80.toByte(); floatBytes[3] = 0x3F + // 2.0f = 0x40000000 + floatBytes[4] = 0x00; floatBytes[5] = 0x00; floatBytes[6] = 0x00; floatBytes[7] = 0x40 + // 3.0f = 0x40400000 + floatBytes[8] = 0x00; floatBytes[9] = 0x00; floatBytes[10] = 0x40; floatBytes[11] = 0x40 + + val storage = TensorStorageFactory.fromRawBytes( + shape = Shape(3), + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Dense(4), + data = floatBytes + ) + + val td = TensorStorageFactory.toTensorData(storage) + assertTrue(td is FloatArrayTensorData<*>) + assertEquals(1f, (td as FloatArrayTensorData<*>).buffer[0]) + assertEquals(2f, td.buffer[1]) + assertEquals(3f, td.buffer[2]) + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/TensorStorageContractTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/TensorStorageContractTest.kt new file mode 100644 index 00000000..fc02605b --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/TensorStorageContractTest.kt @@ -0,0 
+1,222 @@ +package sk.ainet.lang.tensor.storage + +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.types.* +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class TensorStorageContractTest { + + // --- LogicalDType --- + + @Test + fun logicalDTypeFromDTypeRoundTrips() { + assertEquals(LogicalDType.FLOAT32, LogicalDType.fromDType(FP32)) + assertEquals(LogicalDType.FLOAT16, LogicalDType.fromDType(FP16)) + assertEquals(LogicalDType.BFLOAT16, LogicalDType.fromDType(BF16)) + assertEquals(LogicalDType.INT32, LogicalDType.fromDType(Int32)) + assertEquals(LogicalDType.INT4, LogicalDType.fromDType(Int4)) + assertEquals(LogicalDType.TERNARY, LogicalDType.fromDType(Ternary)) + assertEquals(LogicalDType.UINT8, LogicalDType.fromDType(UInt8)) + } + + @Test + fun logicalDTypeSizeInBytes() { + assertEquals(4, LogicalDType.FLOAT32.sizeInBytes) + assertEquals(2, LogicalDType.FLOAT16.sizeInBytes) + assertEquals(2, LogicalDType.BFLOAT16.sizeInBytes) + assertEquals(4, LogicalDType.INT32.sizeInBytes) + assertEquals(1, LogicalDType.INT8.sizeInBytes) + assertEquals(1, LogicalDType.INT4.sizeInBytes) // 4 bits rounds up to 1 byte + } + + @Test + fun logicalDTypeProperties() { + assertTrue(LogicalDType.FLOAT32.isFloatingPoint) + assertTrue(LogicalDType.FLOAT32.isSigned) + assertFalse(LogicalDType.UINT8.isSigned) + assertFalse(LogicalDType.INT32.isFloatingPoint) + } + + // --- TensorEncoding --- + + @Test + fun denseEncodingPhysicalBytes() { + val fp32Dense = TensorEncoding.Dense(bytesPerElement = 4) + assertEquals(4000L, fp32Dense.physicalBytes(1000)) + assertEquals("Dense(4B)", fp32Dense.name) + } + + @Test + fun q4kEncodingPhysicalBytes() { + // 256 elements per 144-byte block + assertEquals(144L, TensorEncoding.Q4_K.physicalBytes(256)) + assertEquals(288L, TensorEncoding.Q4_K.physicalBytes(257)) // 2 blocks needed + assertEquals(144L, TensorEncoding.Q4_K.physicalBytes(1)) // at least 1 block + } + + @Test + 
fun q80EncodingPhysicalBytes() { + // 32 elements per 34-byte block + assertEquals(34L, TensorEncoding.Q8_0.physicalBytes(32)) + assertEquals(68L, TensorEncoding.Q8_0.physicalBytes(33)) // 2 blocks + } + + @Test + fun ternaryEncodingPhysicalBytes() { + assertEquals(1L, TensorEncoding.TernaryPacked.physicalBytes(4)) + assertEquals(2L, TensorEncoding.TernaryPacked.physicalBytes(5)) + } + + // --- BufferHandle --- + + @Test + fun ownedBufferProperties() { + val data = ByteArray(100) + val handle = BufferHandle.Owned(data) + assertEquals(100L, handle.sizeInBytes) + assertTrue(handle.isMutable) + assertEquals(Ownership.OWNED, handle.ownership) + } + + @Test + fun borrowedBufferProperties() { + val data = ByteArray(64) + val handle = BufferHandle.Borrowed(data, isMutable = false) + assertEquals(64L, handle.sizeInBytes) + assertFalse(handle.isMutable) + assertEquals(Ownership.BORROWED, handle.ownership) + } + + @Test + fun aliasedBufferProperties() { + val parent = BufferHandle.Owned(ByteArray(100)) + val alias = BufferHandle.Aliased(parent, byteOffset = 10, sizeInBytes = 50) + assertEquals(50L, alias.sizeInBytes) + assertTrue(alias.isMutable) // inherits parent mutability + assertEquals(Ownership.ALIASED, alias.ownership) + } + + @Test + fun fileBackedBufferProperties() { + val handle = BufferHandle.FileBacked(path = "/model/weights.bin", fileOffset = 0, sizeInBytes = 1024) + assertEquals(1024L, handle.sizeInBytes) + assertFalse(handle.isMutable) + assertEquals(Ownership.FILE_BACKED, handle.ownership) + } + + @Test + fun deviceResidentBufferProperties() { + val handle = BufferHandle.DeviceResident( + deviceId = "gpu:0", backendHandle = "opaque", sizeInBytes = 2048, isMutable = true + ) + assertEquals(2048L, handle.sizeInBytes) + assertTrue(handle.isMutable) + assertEquals(Ownership.DEVICE_RESIDENT, handle.ownership) + } + + @Test + fun aliasedBufferWithOffsetAndSize() { + val parent = BufferHandle.Owned(ByteArray(200)) + val alias = BufferHandle.Aliased(parent, 
byteOffset = 100, sizeInBytes = 100) + assertEquals(100L, alias.sizeInBytes) + assertEquals(100L, alias.byteOffset) + } + + // --- Placement --- + + @Test + fun defaultPlacementPresets() { + val cpuHeap = Placement.CPU_HEAP + assertEquals(DeviceKind.CPU, cpuHeap.device) + assertEquals(MemoryDomain.HOST_HEAP, cpuHeap.domain) + assertEquals(Residency.TRANSIENT, cpuHeap.residency) + + val mmapWeights = Placement.MMAP_WEIGHTS + assertEquals(MemoryDomain.MMAP_FILE, mmapWeights.domain) + assertEquals(Residency.PERSISTENT, mmapWeights.residency) + + val gpuPreferred = Placement.GPU_PREFERRED + assertEquals(DeviceKind.GPU, gpuPreferred.device) + assertEquals(DeviceKind.CPU, gpuPreferred.fallback) + assertEquals(Requirement.PREFERRED, gpuPreferred.requirement) + } + + // --- TensorStorage --- + + @Test + fun tensorStorageDenseFloat32() { + val shape = Shape(2, 3) + val data = ByteArray(24) // 6 elements * 4 bytes + val storage = TensorStorage( + shape = shape, + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Dense(4), + buffer = BufferHandle.Owned(data) + ) + assertEquals(6L, storage.elementCount) + assertEquals(24L, storage.logicalBytes) // 6 * 4 + assertEquals(24L, storage.physicalBytes) + assertFalse(storage.isFileBacked) + assertFalse(storage.isAlias) + assertTrue(storage.isMutable) + assertEquals(Ownership.OWNED, storage.ownership) + } + + @Test + fun tensorStorageQ4KPacked() { + val shape = Shape(256) + val data = ByteArray(144) // 1 Q4_K block + val storage = TensorStorage( + shape = shape, + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Q4_K, + buffer = BufferHandle.Borrowed(data) + ) + assertEquals(256L, storage.elementCount) + assertEquals(1024L, storage.logicalBytes) // 256 * 4 (FP32 logical) + assertEquals(144L, storage.physicalBytes) // 1 Q4_K block + assertFalse(storage.isMutable) + assertEquals(Ownership.BORROWED, storage.ownership) + } + + @Test + fun tensorStorageFileBackedWeights() { + val shape = Shape(1024, 768) + val 
storage = TensorStorage( + shape = shape, + logicalType = LogicalDType.FLOAT16, + encoding = TensorEncoding.Dense(2), + buffer = BufferHandle.FileBacked("/model.bin", fileOffset = 4096, sizeInBytes = 1024L * 768 * 2), + placement = Placement.MMAP_WEIGHTS + ) + assertTrue(storage.isFileBacked) + assertFalse(storage.isMutable) + assertEquals(Residency.PERSISTENT, storage.placement.residency) + assertEquals(MemoryDomain.MMAP_FILE, storage.placement.domain) + } + + // --- StorageMemoryReport --- + + @Test + fun memoryReportForQ4K() { + val shape = Shape(256) + val storage = TensorStorage( + shape = shape, + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Q4_K, + buffer = BufferHandle.Borrowed(ByteArray(144)) + ) + val report = storage.memoryReport() + assertEquals(LogicalDType.FLOAT32, report.logicalType) + assertEquals("Q4_K", report.encoding.name) + assertEquals(Ownership.BORROWED, report.ownership) + assertEquals(1024L, report.logicalBytes) + assertEquals(144L, report.physicalBytes) + assertTrue(report.compressionRatio > 1.0) // Q4_K is smaller than dense FP32 + assertFalse(report.isFileBacked) + assertFalse(report.isMutable) + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/TernaryDequantizationTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/TernaryDequantizationTest.kt new file mode 100644 index 00000000..37a2d946 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/TernaryDequantizationTest.kt @@ -0,0 +1,106 @@ +package sk.ainet.lang.tensor.storage + +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.data.Ternary2BitTensorData +import sk.ainet.lang.tensor.data.toFloatArray +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +class TernaryDequantizationTest { + + @Test + fun dequantizeBlock_allMinusOnes_producesNegativeScale() { + // Encoding: 0→-1, 
so 0x00 = four -1 values per byte + val packed = ByteArray(2) { 0x00 } // 8 elements, all -1 + val td = Ternary2BitTensorData(Shape(8), packed, scale = 2.0f) + val ps = td as PackedBlockStorage + + val output = FloatArray(8) + ps.dequantizeBlock(0, output) + + for (i in 0 until 8) { + assertEquals(-2.0f, output[i], "Element $i should be -1 * 2.0 = -2.0") + } + } + + @Test + fun dequantizeBlock_allZeros_producesZeros() { + // Encoding: 1→0, so 0x55 = 01_01_01_01 = four 0 values per byte + val packed = ByteArray(2) { 0x55 } + val td = Ternary2BitTensorData(Shape(8), packed, scale = 5.0f) + val ps = td as PackedBlockStorage + + val output = FloatArray(8) + ps.dequantizeBlock(0, output) + + for (i in 0 until 8) { + assertEquals(0.0f, output[i], "Element $i should be 0 * 5.0 = 0.0") + } + } + + @Test + fun dequantizeBlock_allPlusOnes_producesPositiveScale() { + // Encoding: 2→+1, so 0xAA = 10_10_10_10 = four +1 values per byte + val packed = ByteArray(2) { 0xAA.toByte() } + val td = Ternary2BitTensorData(Shape(8), packed, scale = 3.0f) + val ps = td as PackedBlockStorage + + val output = FloatArray(8) + ps.dequantizeBlock(0, output) + + for (i in 0 until 8) { + assertEquals(3.0f, output[i], "Element $i should be +1 * 3.0 = 3.0") + } + } + + @Test + fun dequantizeBlock_mixedValues_matchesToFloatArray() { + // Mixed: -1, 0, +1, -1 per byte → 2-bit codes 00, 01, 10, 00 + // Byte: bits [1:0]=00(-1), [3:2]=01(0), [5:4]=10(+1), [7:6]=00(-1) + // = 0b00_10_01_00 = 0x24 + val packed = byteArrayOf(0x24, 0x24) + val td = Ternary2BitTensorData(Shape(8), packed, scale = 1.0f) + + // Verify via PackedBlockStorage + val ps = td as PackedBlockStorage + val output = FloatArray(8) + ps.dequantizeBlock(0, output) + + // Also verify via extension function + val expected = td.toFloatArray() + + for (i in 0 until 8) { + assertEquals(expected[i], output[i], "Element $i: dequantizeBlock should match toFloatArray") + } + } + + @Test + fun dequantizeBlock_withOutputOffset_writesAtOffset() { + val 
packed = ByteArray(1) { 0xAA.toByte() } // 4 elements, all +1 + val td = Ternary2BitTensorData(Shape(4), packed, scale = 1.0f) + val ps = td as PackedBlockStorage + + val output = FloatArray(14) // larger than needed + ps.dequantizeBlock(0, output, outputOffset = 10) + + // Elements [0..9] should be untouched (0.0) + for (i in 0 until 10) { + assertEquals(0.0f, output[i], "Element $i should be untouched") + } + // Elements [10..13] should be 1.0 + for (i in 10 until 14) { + assertEquals(1.0f, output[i], "Element $i should be 1.0") + } + } + + @Test + fun dequantizeBlock_invalidBlockIndex_throws() { + val packed = ByteArray(1) { 0x55 } + val td = Ternary2BitTensorData(Shape(4), packed) as PackedBlockStorage + + assertFailsWith { + td.dequantizeBlock(1, FloatArray(4)) // only block 0 valid + } + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/TransferOpsTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/TransferOpsTest.kt new file mode 100644 index 00000000..06ef995c --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/TransferOpsTest.kt @@ -0,0 +1,135 @@ +package sk.ainet.lang.tensor.storage + +import sk.ainet.lang.tensor.Shape +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNotSame +import kotlin.test.assertSame + +class TransferOpsTest { + + private fun ownedStorage(bytes: ByteArray = ByteArray(16) { it.toByte() }) = TensorStorage( + shape = Shape(4), + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Dense(4), + buffer = BufferHandle.Owned(bytes), + placement = Placement.CPU_HEAP + ) + + private fun borrowedStorage() = TensorStorage( + shape = Shape(4), + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Dense(4), + buffer = BufferHandle.Borrowed(ByteArray(16) { it.toByte() }), + placement = Placement(device = 
DeviceKind.CPU, domain = MemoryDomain.MMAP_FILE) + ) + + // --- copyMaterialize --- + + @Test + fun copyMaterialize_ownedBuffer_producesIndependentCopy() { + val original = ByteArray(16) { it.toByte() } + val storage = ownedStorage(original) + val copy = storage.copyMaterialize() + + assertEquals(Ownership.OWNED, copy.ownership) + assertEquals(storage.shape, copy.shape) + assertEquals(storage.logicalType, copy.logicalType) + assertEquals(storage.encoding, copy.encoding) + assertEquals(MemoryDomain.HOST_HEAP, copy.placement.domain) + + // Modifying original doesn't affect copy + original[0] = 99 + val copyData = (copy.buffer as BufferHandle.Owned).data + assertEquals(0, copyData[0]) + } + + @Test + fun copyMaterialize_borrowedBuffer_producesOwnedCopy() { + val storage = borrowedStorage() + val copy = storage.copyMaterialize() + + assertEquals(Ownership.OWNED, copy.ownership) + assertEquals(MemoryDomain.HOST_HEAP, copy.placement.domain) + } + + @Test + fun copyMaterialize_fileBackedBuffer_throwsUnsupported() { + val storage = TensorStorage( + shape = Shape(4), + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Dense(4), + buffer = BufferHandle.FileBacked("/model.bin", 0, 16), + placement = Placement.MMAP_WEIGHTS + ) + assertFailsWith { + storage.copyMaterialize() + } + } + + @Test + fun copyMaterialize_deviceResidentBuffer_throwsUnsupported() { + val storage = TensorStorage( + shape = Shape(4), + logicalType = LogicalDType.FLOAT32, + encoding = TensorEncoding.Dense(4), + buffer = BufferHandle.DeviceResident("gpu:0", "opaque", 16, true) + ) + assertFailsWith { + storage.copyMaterialize() + } + } + + // --- copyToHost --- + + @Test + fun copyToHost_alreadyOnHost_returnsSameInstance() { + val storage = ownedStorage() + val result = storage.copyToHost() + assertSame(storage, result) + } + + @Test + fun copyToHost_nonHostPlacement_copies() { + val storage = borrowedStorage() // domain = MMAP_FILE, not HOST_HEAP + val result = storage.copyToHost() + 
assertNotSame(storage, result) + assertEquals(Ownership.OWNED, result.ownership) + assertEquals(MemoryDomain.HOST_HEAP, result.placement.domain) + } + + // --- copyToDevice --- + + @Test + fun copyToDevice_cpu_delegatesToCopyToHost() { + val storage = ownedStorage() + val result = storage.copyToDevice(DeviceKind.CPU) + assertSame(storage, result) // already on CPU heap + } + + @Test + fun copyToDevice_gpu_throwsUnsupported() { + val storage = ownedStorage() + assertFailsWith { + storage.copyToDevice(DeviceKind.GPU) + } + } + + // --- repackTo --- + + @Test + fun repackTo_sameEncoding_returnsSameInstance() { + val storage = ownedStorage() + val result = storage.repackTo(TensorEncoding.Dense(4)) + assertSame(storage, result) + } + + @Test + fun repackTo_differentEncoding_throwsUnsupported() { + val storage = ownedStorage() + assertFailsWith { + storage.repackTo(TensorEncoding.Q4_K) + } + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/TurboQuantKvCacheStoreTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/TurboQuantKvCacheStoreTest.kt new file mode 100644 index 00000000..03b1ce1f --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/storage/TurboQuantKvCacheStoreTest.kt @@ -0,0 +1,228 @@ +package sk.ainet.lang.tensor.storage + +import sk.ainet.lang.tensor.ops.turboquant.TurboQuantConfig +import kotlin.math.abs +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +/** + * Tests for [TurboQuantKvCacheStore] — the compressed KV cache. 
+ */ +class TurboQuantKvCacheStoreTest { + + private fun createStore( + numLayers: Int = 1, + numHeads: Int = 2, + headDim: Int = 64, + maxSeqLen: Int = 16, + bits: Int = 4, + useQjl: Boolean = false + ): TurboQuantKvCacheStore { + val config = KvCacheConfig( + numLayers = numLayers, + numHeads = numHeads, + headDim = headDim, + maxSeqLen = maxSeqLen, + keyEncoding = TensorEncoding.TurboQuantPolar(bitsPerElement = bits), + valueEncoding = TensorEncoding.TurboQuantPolar(bitsPerElement = bits) + ) + val quantConfig = if (useQjl) { + TurboQuantConfig.polarPlusQjl(bits = bits) + } else { + TurboQuantConfig.polarOnly(bits = bits) + } + return TurboQuantKvCacheStore(config, quantConfig, quantConfig) + } + + @Test + fun appendAndReadBasic() { + val store = createStore(numLayers = 1, numHeads = 1, headDim = 64) + val key = FloatArray(64) { (it - 32).toFloat() / 32f } + val value = FloatArray(64) { (it - 32).toFloat() / 64f } + + store.appendToken(0, key, value) + assertEquals(1, store.currentSeqLen) + + val readK = store.readKeys(0) + val readV = store.readValues(0) + + assertEquals(64, readK.size) + assertEquals(64, readV.size) + + // Check reconstruction accuracy (4-bit should be reasonable) + var maxKeyError = 0f + for (i in key.indices) { + maxKeyError = maxOf(maxKeyError, abs(key[i] - readK[i])) + } + assertTrue(maxKeyError < 0.5f, + "4-bit TurboQuant key reconstruction error should be < 0.5, got $maxKeyError") + } + + @Test + fun multipleTokens() { + val store = createStore(numLayers = 1, numHeads = 1, headDim = 64, maxSeqLen = 8) + + for (t in 0 until 4) { + val key = FloatArray(64) { (it + t).toFloat() / 64f } + val value = FloatArray(64) { (it - t).toFloat() / 64f } + store.appendToken(0, key, value) + } + + assertEquals(4, store.currentSeqLen) + + val allKeys = store.readKeys(0) + // [numHeads=1, seqLen=4, headDim=64] + assertEquals(1 * 4 * 64, allKeys.size) + } + + @Test + fun multipleHeads() { + val store = createStore(numLayers = 1, numHeads = 4, headDim = 64) 
+ + val key = FloatArray(4 * 64) { it.toFloat() / 256f } + val value = FloatArray(4 * 64) { -it.toFloat() / 256f } + store.appendToken(0, key, value) + + val readK = store.readKeys(0) + assertEquals(4 * 1 * 64, readK.size) + } + + @Test + fun rangeRead() { + val store = createStore(numLayers = 1, numHeads = 1, headDim = 64, maxSeqLen = 8) + + for (t in 0 until 4) { + store.appendToken(0, FloatArray(64) { t.toFloat() }, FloatArray(64)) + } + + val partial = store.readKeys(0, startPos = 1, endPos = 3) + assertEquals(1 * 2 * 64, partial.size) // 2 positions + } + + @Test + fun eviction() { + val store = createStore(numLayers = 1, numHeads = 1, headDim = 64, maxSeqLen = 8) + + for (t in 0 until 4) { + store.appendToken(0, FloatArray(64), FloatArray(64)) + } + assertEquals(4, store.currentSeqLen) + + store.evict(2) + assertEquals(2, store.currentSeqLen) + } + + @Test + fun clear() { + val store = createStore() + store.appendToken(0, FloatArray(2 * 64), FloatArray(2 * 64)) + assertEquals(1, store.currentSeqLen) + + store.clear() + assertEquals(0, store.currentSeqLen) + } + + @Test + fun capacityOverflow() { + val store = createStore(numLayers = 1, numHeads = 1, headDim = 64, maxSeqLen = 2) + store.appendToken(0, FloatArray(64), FloatArray(64)) + store.appendToken(0, FloatArray(64), FloatArray(64)) + + assertFailsWith { + store.appendToken(0, FloatArray(64), FloatArray(64)) + } + } + + @Test + fun compressionRatio() { + val store = createStore(numLayers = 1, numHeads = 1, headDim = 128, maxSeqLen = 8, bits = 4) + + for (t in 0 until 4) { + store.appendToken(0, FloatArray(128) { it.toFloat() }, FloatArray(128)) + } + + val report = store.memoryReport() + // 4-bit should compress significantly vs FP32 + assertTrue(report.compressionRatio > 1.5, + "4-bit TurboQuant should compress at least 1.5x, got ${report.compressionRatio}") + } + + @Test + fun qjlVariant() { + val store = createStore(numLayers = 1, numHeads = 1, headDim = 64, bits = 4, useQjl = true) + + val key = 
FloatArray(64) { (it - 32).toFloat() / 32f } + store.appendToken(0, key, FloatArray(64)) + + val readK = store.readKeys(0) + assertEquals(64, readK.size) + } + + @Test + fun multipleLayers() { + val store = createStore(numLayers = 2, numHeads = 1, headDim = 64) + + val key0 = FloatArray(64) { 1f } + val key1 = FloatArray(64) { -1f } + + store.appendToken(0, key0, FloatArray(64)) + store.appendToken(1, key1, FloatArray(64)) + + assertEquals(1, store.currentSeqLen) + + val readK0 = store.readKeys(0) + val readK1 = store.readKeys(1) + + // Layer 0 should reconstruct toward positive, layer 1 toward negative + val avgK0 = readK0.sum() / readK0.size + val avgK1 = readK1.sum() / readK1.size + assertTrue(avgK0 > avgK1, "Layer 0 (pos) avg ($avgK0) should > layer 1 (neg) avg ($avgK1)") + } + + @Test + fun memoryReportAccurate() { + val store = createStore(numLayers = 2, numHeads = 2, headDim = 64, maxSeqLen = 8) + store.appendToken(0, FloatArray(128), FloatArray(128)) + store.appendToken(1, FloatArray(128), FloatArray(128)) + + val report = store.memoryReport() + assertEquals(2, report.numLayers) + assertEquals(2, report.numHeads) + assertEquals(64, report.headDim) + assertEquals(1, report.currentSeqLen) + assertTrue(report.totalPhysicalBytes > 0) + assertTrue(report.totalLogicalBytes > 0) + } + + @Test + fun asymmetricKeyValueConfig() { + val config = KvCacheConfig( + numLayers = 1, numHeads = 1, headDim = 64, maxSeqLen = 8, + keyEncoding = TensorEncoding.TurboQuantPolar(bitsPerElement = 8), + valueEncoding = TensorEncoding.TurboQuantPolar(bitsPerElement = 4) + ) + val store = TurboQuantKvCacheStore( + config, + keyConfig = TurboQuantConfig.polarOnly(bits = 8), + valueConfig = TurboQuantConfig.polarOnly(bits = 4) + ) + + val input = FloatArray(64) { (it - 32).toFloat() / 32f } + store.appendToken(0, input, input) + + val readK = store.readKeys(0) + val readV = store.readValues(0) + + // 8-bit keys should be more accurate than 4-bit values + var keyError = 0f + var valError 
= 0f + for (i in input.indices) { + keyError += abs(input[i] - readK[i]) + valError += abs(input[i] - readV[i]) + } + assertTrue(keyError < valError, + "8-bit keys ($keyError) should have less error than 4-bit values ($valError)") + } +} diff --git a/skainet-lang/skainet-lang-core/src/jvmMain/kotlin/sk/ainet/lang/tensor/StorageBenchmarks.kt b/skainet-lang/skainet-lang-core/src/jvmMain/kotlin/sk/ainet/lang/tensor/StorageBenchmarks.kt new file mode 100644 index 00000000..23a9398f --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/jvmMain/kotlin/sk/ainet/lang/tensor/StorageBenchmarks.kt @@ -0,0 +1,143 @@ +package sk.ainet.lang.tensor + +import kotlinx.benchmark.* +import sk.ainet.lang.tensor.data.* +import sk.ainet.lang.tensor.storage.* +import sk.ainet.lang.types.DType +import sk.ainet.lang.types.FP32 +import kotlin.random.Random + +/** + * JMH benchmarks for the memory-first storage layer. + * + * Run: ./gradlew :skainet-lang:skainet-lang-core:jvmBenchmark + */ + +// --- Array creation: borrowed (wrap) vs copied (from) --- + +@State(Scope.Benchmark) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(BenchmarkTimeUnit.MICROSECONDS) +public open class ArrayCreationBenchmark { + private val factory = DenseTensorDataFactory() + private val shape = Shape(1024, 1024) // 1M elements + private lateinit var floatData: FloatArray + + @Setup + public fun setup() { + floatData = FloatArray(1024 * 1024) { Random.nextFloat() } + } + + @Benchmark + public fun wrapFloatArray_zeroCopy(): TensorData = + factory.wrapFloatArray(shape, FP32::class, floatData) + + @Benchmark + public fun fromFloatArray_copy(): TensorData = + factory.fromFloatArray(shape, FP32::class, floatData) +} + +// --- Dequantization throughput --- + +@State(Scope.Benchmark) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(BenchmarkTimeUnit.MICROSECONDS) +public open class DequantizationBenchmark { + private lateinit var q4kData: Q4_KBlockTensorData + private lateinit var q80Data: Q8_0BlockTensorData + private 
lateinit var ternaryData: Ternary2BitTensorData + + @Setup + public fun setup() { + // Q4_K: 100 blocks = 25600 elements + val q4kBytes = ByteArray(100 * Q4_KTensorData.BYTES_PER_BLOCK) + Random.nextBytes(q4kBytes) + q4kData = Q4_KBlockTensorData.fromRawBytes(Shape(25600), q4kBytes) + + // Q8_0: 800 blocks = 25600 elements + val q80Bytes = ByteArray(800 * Q8_0TensorData.BYTES_PER_BLOCK) + Random.nextBytes(q80Bytes) + q80Data = Q8_0BlockTensorData.fromRawBytes(Shape(25600), q80Bytes) + + // Ternary: 25600 elements = 6400 packed bytes + val ternaryBytes = ByteArray(6400) + Random.nextBytes(ternaryBytes) + ternaryData = Ternary2BitTensorData(Shape(25600), ternaryBytes) + } + + @Benchmark + public fun dequantQ4K(): FloatArray = (q4kData as PackedBlockStorage).toFloatArray() + + @Benchmark + public fun dequantQ8_0(): FloatArray = (q80Data as PackedBlockStorage).toFloatArray() + + @Benchmark + public fun dequantTernary(): FloatArray = (ternaryData as PackedBlockStorage).toFloatArray() +} + +// --- BufferAccessor read performance --- + +@State(Scope.Benchmark) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(BenchmarkTimeUnit.MICROSECONDS) +public open class BufferAccessorBenchmark { + private lateinit var accessor: ByteArrayAccessor + private val readSize = 1024 + + @Setup + public fun setup() { + val data = ByteArray(1024 * 1024) // 1 MB + Random.nextBytes(data) + accessor = ByteArrayAccessor(data) + } + + @Benchmark + public fun heapAccessor_readBytes_1KB(): ByteArray = + accessor.readBytes(512_000, readSize) + + @Benchmark + public fun heapAccessor_readByte_sequential(): Long { + var sum = 0L + for (i in 0 until readSize) { + sum += accessor.readByte(i.toLong()) + } + return sum + } +} + +// --- TensorData <-> TensorStorage bridge --- + +@State(Scope.Benchmark) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(BenchmarkTimeUnit.MICROSECONDS) +public open class StorageBridgeBenchmark { + private lateinit var floatTd: DenseFloatArrayTensorData + private lateinit var 
q4kTd: Q4_KBlockTensorData + private lateinit var floatStorage: TensorStorage + private lateinit var q4kStorage: TensorStorage + + @Setup + public fun setup() { + floatTd = DenseFloatArrayTensorData(Shape(1024), FloatArray(1024) { it.toFloat() }) + q4kTd = Q4_KBlockTensorData.fromRawBytes(Shape(256), ByteArray(144)) + + floatStorage = TensorStorageFactory.fromTensorData(floatTd) + q4kStorage = TensorStorageFactory.fromTensorData(q4kTd) + } + + @Benchmark + public fun floatTensorData_toStorage(): TensorStorage = + TensorStorageFactory.fromTensorData(floatTd) + + @Benchmark + public fun q4kTensorData_toStorage(): TensorStorage = + TensorStorageFactory.fromTensorData(q4kTd) + + @Benchmark + public fun storage_toTensorData_float(): TensorData = + TensorStorageFactory.toTensorData(floatStorage) + + @Benchmark + public fun storage_toTensorData_q4k(): TensorData = + TensorStorageFactory.toTensorData(q4kStorage) +} diff --git a/skainet-lang/skainet-lang-core/src/jvmMain/kotlin/sk/ainet/lang/tensor/TurboQuantBenchmarks.kt b/skainet-lang/skainet-lang-core/src/jvmMain/kotlin/sk/ainet/lang/tensor/TurboQuantBenchmarks.kt new file mode 100644 index 00000000..3254432d --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/jvmMain/kotlin/sk/ainet/lang/tensor/TurboQuantBenchmarks.kt @@ -0,0 +1,244 @@ +package sk.ainet.lang.tensor + +import kotlinx.benchmark.* +import sk.ainet.lang.tensor.ops.turboquant.* +import sk.ainet.lang.tensor.storage.* +import kotlin.random.Random + +/** + * JMH benchmarks for TurboQuant KV-cache compression. + * + * Measures encode/decode throughput, compression ratio, and accuracy + * for different TurboQuant configurations. 
+ * + * Run: ./gradlew :skainet-lang:skainet-lang-core:jvmBenchmark + */ + +// --- Encode throughput --- + +@State(Scope.Benchmark) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(BenchmarkTimeUnit.MICROSECONDS) +public open class TurboQuantEncodeBenchmark { + private lateinit var vector128: FloatArray + private lateinit var vector256: FloatArray + private lateinit var vector512: FloatArray + private lateinit var config4Bit: TurboQuantConfig + private lateinit var config3Bit: TurboQuantConfig + private lateinit var config8Bit: TurboQuantConfig + private lateinit var configQjl: TurboQuantConfig + + @Setup + public fun setup() { + val rng = Random(42) + vector128 = FloatArray(128) { rng.nextFloat() * 2 - 1 } + vector256 = FloatArray(256) { rng.nextFloat() * 2 - 1 } + vector512 = FloatArray(512) { rng.nextFloat() * 2 - 1 } + config4Bit = TurboQuantConfig.polarOnly(bits = 4, seed = 42) + config3Bit = TurboQuantConfig.polarOnly(bits = 3, seed = 42) + config8Bit = TurboQuantConfig.polarOnly(bits = 8, seed = 42) + configQjl = TurboQuantConfig.polarPlusQjl(bits = 4, residualBits = 1, seed = 42) + } + + @Benchmark + public fun encode_4bit_128d(): TurboQuantBlock = + TurboQuantCodec.encode(vector128, config4Bit) + + @Benchmark + public fun encode_4bit_256d(): TurboQuantBlock = + TurboQuantCodec.encode(vector256, config4Bit) + + @Benchmark + public fun encode_3bit_128d(): TurboQuantBlock = + TurboQuantCodec.encode(vector128, config3Bit) + + @Benchmark + public fun encode_8bit_128d(): TurboQuantBlock = + TurboQuantCodec.encode(vector128, config8Bit) + + @Benchmark + public fun encode_4bit_qjl_128d(): TurboQuantBlock = + TurboQuantCodec.encode(vector128, configQjl) +} + +// --- Decode throughput --- + +@State(Scope.Benchmark) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(BenchmarkTimeUnit.MICROSECONDS) +public open class TurboQuantDecodeBenchmark { + private lateinit var block4Bit128: TurboQuantBlock + private lateinit var block4Bit256: TurboQuantBlock + private 
lateinit var block3Bit128: TurboQuantBlock + private lateinit var block8Bit128: TurboQuantBlock + private lateinit var blockQjl128: TurboQuantBlock + + @Setup + public fun setup() { + val rng = Random(42) + val v128 = FloatArray(128) { rng.nextFloat() * 2 - 1 } + val v256 = FloatArray(256) { rng.nextFloat() * 2 - 1 } + + block4Bit128 = TurboQuantCodec.encode(v128, TurboQuantConfig.polarOnly(bits = 4, seed = 42)) + block4Bit256 = TurboQuantCodec.encode(v256, TurboQuantConfig.polarOnly(bits = 4, seed = 42)) + block3Bit128 = TurboQuantCodec.encode(v128, TurboQuantConfig.polarOnly(bits = 3, seed = 42)) + block8Bit128 = TurboQuantCodec.encode(v128, TurboQuantConfig.polarOnly(bits = 8, seed = 42)) + blockQjl128 = TurboQuantCodec.encode(v128, TurboQuantConfig.polarPlusQjl(bits = 4, seed = 42)) + } + + @Benchmark + public fun decode_4bit_128d(): FloatArray = + TurboQuantCodec.decode(block4Bit128) + + @Benchmark + public fun decode_4bit_256d(): FloatArray = + TurboQuantCodec.decode(block4Bit256) + + @Benchmark + public fun decode_3bit_128d(): FloatArray = + TurboQuantCodec.decode(block3Bit128) + + @Benchmark + public fun decode_8bit_128d(): FloatArray = + TurboQuantCodec.decode(block8Bit128) + + @Benchmark + public fun decode_4bit_qjl_128d(): FloatArray = + TurboQuantCodec.decode(blockQjl128) +} + +// --- Bit-packing throughput --- + +@State(Scope.Benchmark) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(BenchmarkTimeUnit.MICROSECONDS) +public open class BitPackerBenchmark { + private lateinit var codes128: ByteArray + private lateinit var codes1024: ByteArray + private lateinit var packed4Bit: ByteArray + private lateinit var packed2Bit: ByteArray + + @Setup + public fun setup() { + codes128 = ByteArray(128) { (it % 7 - 3).toByte() } + codes1024 = ByteArray(1024) { (it % 7 - 3).toByte() } + packed4Bit = BitPacker.pack(codes1024, 4) + packed2Bit = BitPacker.pack(ByteArray(1024) { (it % 3 - 1).toByte() }, 2) + } + + @Benchmark + public fun pack_4bit_128(): ByteArray = 
BitPacker.pack(codes128, 4) + + @Benchmark + public fun pack_4bit_1024(): ByteArray = BitPacker.pack(codes1024, 4) + + @Benchmark + public fun unpack_4bit_1024(): ByteArray = BitPacker.unpack(packed4Bit, 1024, 4) + + @Benchmark + public fun pack_2bit_1024(): ByteArray = BitPacker.pack(ByteArray(1024) { (it % 3 - 1).toByte() }, 2) + + @Benchmark + public fun unpack_2bit_1024(): ByteArray = BitPacker.unpack(packed2Bit, 1024, 2) +} + +// --- Random rotation throughput --- + +@State(Scope.Benchmark) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(BenchmarkTimeUnit.MICROSECONDS) +public open class RandomRotationBenchmark { + private lateinit var vector128: FloatArray + private lateinit var vector256: FloatArray + + @Setup + public fun setup() { + val rng = Random(42) + vector128 = FloatArray(128) { rng.nextFloat() * 2 - 1 } + vector256 = FloatArray(256) { rng.nextFloat() * 2 - 1 } + } + + @Benchmark + public fun rotate_128d(): FloatArray { + val v = vector128.copyOf() + RandomRotation.rotate(v, 42) + return v + } + + @Benchmark + public fun rotate_256d(): FloatArray { + val v = vector256.copyOf() + RandomRotation.rotate(v, 42) + return v + } + + @Benchmark + public fun rotateInverse_128d(): FloatArray { + val v = vector128.copyOf() + RandomRotation.rotate(v, 42) + RandomRotation.inverseRotate(v, 42) + return v + } +} + +// --- KV cache throughput --- + +@State(Scope.Benchmark) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(BenchmarkTimeUnit.MICROSECONDS) +public open class TurboQuantKvCacheBenchmark { + private lateinit var denseStore: DefaultKvCacheStore + private lateinit var turboStore: TurboQuantKvCacheStore + private lateinit var keyProjection: FloatArray + private lateinit var valueProjection: FloatArray + + @Setup + public fun setup() { + val rng = Random(42) + val numHeads = 8 + val headDim = 128 + val maxSeqLen = 256 + + denseStore = DefaultKvCacheStore( + KvCacheConfig(numLayers = 1, numHeads = numHeads, headDim = headDim, maxSeqLen = maxSeqLen) + ) + 
turboStore = TurboQuantKvCacheStore( + KvCacheConfig( + numLayers = 1, numHeads = numHeads, headDim = headDim, maxSeqLen = maxSeqLen, + keyEncoding = TensorEncoding.TurboQuantPolar(bitsPerElement = 4), + valueEncoding = TensorEncoding.TurboQuantPolar(bitsPerElement = 4) + ), + keyConfig = TurboQuantConfig.polarOnly(bits = 4), + valueConfig = TurboQuantConfig.polarOnly(bits = 4) + ) + + keyProjection = FloatArray(numHeads * headDim) { rng.nextFloat() * 2 - 1 } + valueProjection = FloatArray(numHeads * headDim) { rng.nextFloat() * 2 - 1 } + } + + @Benchmark + public fun appendToken_dense() { + denseStore.clear() + denseStore.appendToken(0, keyProjection, valueProjection) + } + + @Benchmark + public fun appendToken_turbo4bit() { + turboStore.clear() + turboStore.appendToken(0, keyProjection, valueProjection) + } + + @Benchmark + public fun readKeys_dense_16tokens() { + denseStore.clear() + for (i in 0 until 16) denseStore.appendToken(0, keyProjection, valueProjection) + denseStore.readKeys(0) + } + + @Benchmark + public fun readKeys_turbo4bit_16tokens() { + turboStore.clear() + for (i in 0 until 16) turboStore.appendToken(0, keyProjection, valueProjection) + turboStore.readKeys(0) + } +}