diff --git a/.github/workflows/upstream-sync.yml b/.github/workflows/upstream-sync.yml index 29d4cd71..9fd5408f 100644 --- a/.github/workflows/upstream-sync.yml +++ b/.github/workflows/upstream-sync.yml @@ -32,13 +32,15 @@ jobs: - name: Create or Update Sync Branch run: | - git checkout -B sync/upstream-latest - git reset --hard upstream/main + git checkout -B sync/upstream-latest origin/main + git checkout upstream/main -- . + git checkout origin/main -- .github/workflows + git commit -m "chore: sync with upstream/main" || true git push -f origin sync/upstream-latest - name: Create Pull Request to Main env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GH_TOKEN: ${{ secrets.SWIFTLM_PR_TOKEN || secrets.GITHUB_TOKEN }} run: | gh pr create --base main --head sync/upstream-latest \ --title "🔄 Auto-Sync: Apple Upstream Repository" \ diff --git a/Source/MLX/MLXFast.swift b/Source/MLX/MLXFast.swift index 41dccd23..0eaf07ad 100644 --- a/Source/MLX/MLXFast.swift +++ b/Source/MLX/MLXFast.swift @@ -420,6 +420,27 @@ public enum MLXFast { mlx_fast_set_prefetch_enabled(enabled) } + /// Convert an array of raw FP8 E4M3 bytes (stored as uint8 by safetensors loader) + /// to the specified floating point dtype using proper FP8 E4M3 semantics. + /// + /// MLX's safetensors loader maps `F8_E4M3` → `uint8` (raw bit patterns). + /// Use this before applying block-wise scale_inv to dequantize FP8 weights correctly. + /// + /// - Parameters: + /// - x: uint8 array containing raw FP8 E4M3 bit patterns + /// - dtype: target floating-point dtype (e.g. `.bfloat16`, `.float16`, `.float32`) + /// - stream: stream or device to evaluate on + /// - Returns: Array in the requested dtype with correctly converted FP8 values + public static func fromFp8( + _ x: MLXArray, dtype: DType = .bfloat16, stream: StreamOrDevice = .default + ) -> MLXArray { + precondition(x.dtype == .uint8, "FP8 input must be uint8 bit patterns") + precondition(dtype == .bfloat16 || dtype == .float16 || dtype == .float32, "Target dtype must be a floating point type") + var result = mlx_array_new() + mlx_from_fp8(&result, x.ctx, dtype.cmlxDtype, stream.ctx) + return MLXArray(result) + } + } /// Optimized implementation of `NN.RoPE`.