Skip to content

GPU fixes#32

Open
rsenne wants to merge 19 commits into
mainfrom
GPU_Fixes_
Open

GPU fixes#32
rsenne wants to merge 19 commits into
mainfrom
GPU_Fixes_

Conversation

@rsenne

@rsenne rsenne commented May 9, 2026

Copy link
Copy Markdown
Owner

This branch lands the GPU + AD-backend overhaul on top of main. The big themes:

Modular AD via DifferentiationInterface

  1. Make DifferentiationInterface the unified entry point into AD; Enzyme, Mooncake, and ForwardDiff are now [weakdeps] with extensions, not hard deps.
  2. Remove the old hard-coded Enzyme machinery (DEER.DEFAULT_BACKEND etc.) — backend is now a required kwarg on ParallelMALASampler, with the user choosing the AD package they want loaded.
  3. New DEER._hvp_strategy(backend) dispatch picks forward_on_grad (Enzyme/ForwardDiff) vs reverse_on_grad (Mooncake/Zygote) for the AD-HVP fallback. Forward-on-grad is the only AD-HVP path that's reliable on GPU.

EnzymeExt: GPU-safe Enzyme rules

  1. New ext/EnzymeExt.jl (~280 lines) defines native Enzyme forward / augmented_primal / reverse rules for pmcmc_matmul, pmcmc_dot, pmcmc_dotsum. These wrappers are treated as opaque by Enzyme's IR rewriter, sidestepping the gc-transition aborts Enzyme hits when lowering cuBLAS / cuMemcpyDtoHAsync_v2 bundles inside *(::CuArray, ::CuArray).
  2. DEER._hvp_forward_backend / _hvp_closure_backend overloads pin Enzyme.Forward + set_runtime_activity + function_annotation=Const for plain AutoEnzyme(). Runtime activity is load-bearing for composed pmcmc_matmul(transpose(X), pmcmc_matmul(X, β)) calls.
  3. Forward HVP now routes through forward-mode Enzyme with runtime activity (instead of forward-over-reverse, which crashes on MvNormal / Dirichlet log-pdfs).

Tests

  1. New test/test-GPU-AD-HVP.jl and test/test-GPU-Performance.jl covering the GPU AD-HVP path.
  2. New test/test-Owned-Matmul.jl exercises the custom Enzyme rules on pmcmc_matmul / pmcmc_dot / pmcmc_dotsum.
  3. Existing tests updated to pass an explicit backend=ADTypes.AutoEnzyme() (now required).

Needed Changes Prior to Merging

  1. Expand documentation to include a limitations section
  2. Expand documentation to explain new AD choices
  3. Add worked GPU examples as indicated in Add more worked examples #28
  4. Update changelog

Resolves #29 and provides a workaround to #25

@codecov

codecov Bot commented May 9, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 21.68675% with 130 lines in your changes missing coverage. Please review.
✅ Project coverage is 80.06%. Comparing base (9d3b516) to head (9479839).

Files with missing lines Patch % Lines
ext/EnzymeExt.jl 3.96% 97 Missing ⚠️
src/DEER/DEER.jl 36.53% 33 Missing ⚠️

❌ Your patch check has failed because the patch coverage (21.68%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage.
❌ Your project check has failed because the head coverage (80.06%) is below the target coverage (90.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             main      #32      +/-   ##
==========================================
- Coverage   88.65%   80.06%   -8.59%     
==========================================
  Files           6        8       +2     
  Lines        1040     1164     +124     
==========================================
+ Hits          922      932      +10     
- Misses        118      232     +114     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@gdalle

gdalle commented May 10, 2026

Copy link
Copy Markdown
Collaborator

Hey @rsenne, need a review here?

@rsenne

rsenne commented May 10, 2026

Copy link
Copy Markdown
Owner Author

Hi @gdalle yes I would love that. This is my first pass on this and it would be much appreciated!

Comment thread ext/EnzymeExt.jl
Comment thread src/DEER/DEER.jl
Comment on lines +95 to +97
return DI.prepare_pushforward(
f, _hvp_forward_backend(backend), x_template, (v_template,); strict=Val(false)
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use DI.hvp directly here?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it is because you need a batched gradient, you may be interested in JuliaDiff/DifferentiationInterface.jl#991

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that you can already batch a small number of tangents by passing a tuple though

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two reasons:

  1. The lack of batching (happy to help tackle this!)
  2. Perhaps, more importantly, every second order method I've tried breaks

Here are two MRWE if interested

#=
Direct `DI.hvp(logp, ...)` fails for every Mooncake-based config on GPU.
A single Mooncake reverse pass over the user's `gradlogp` succeeds.

Target (same shape as `test/test-GPU-AD-HVP.jl`):

  logp(β)     = -0.5 * (||β||^2 + ||Xβ||^2 / N)
  gradlogp(β) = -(β + Xᵀ X β / N)
  Hv         = -v - Xᵀ X v / N

Run from repo root:

  julia --project=test dev/di_hvp_gpu_mwe.jl
=#

push!(LOAD_PATH, abspath(joinpath(@__DIR__, "..")))

using ParallelMCMC
using ADTypes
using DifferentiationInterface
const DI = DifferentiationInterface
import Mooncake
import ForwardDiff
import CUDA
import LinearAlgebra: dot
using Random

CUDA.functional() || error("requires functional CUDA device")

function _logp_single(β, X)
    Xβ = pmcmc_matmul(X, β)
    N = oftype(zero(eltype(β)), size(X, 1))
    -oftype(zero(eltype(β)), 0.5) * (sum(abs2, β) + sum(abs2, Xβ) / N)
end

function _gradlogp_single(β, X)
    Xβ = pmcmc_matmul(X, β)
    N  = oftype(zero(eltype(β)), size(X, 1))
    Y  = pmcmc_matmul(transpose(X), Xβ)
    Y  = Y ./ N
    Y  = Y .+ β
    return -Y
end

const D    = 20
const Ndat = 64

rng = MersenneTwister(20251231)
X_cpu = randn(rng, Float32, Ndat, D)
X_gpu = CUDA.CuMatrix(X_cpu)

β_cpu = randn(rng, Float32, D)
v_cpu = ones(Float32, D)
β_gpu = CUDA.CuArray(β_cpu)
v_gpu = CUDA.CuArray(v_cpu)

logp(β)     = _logp_single(β, X_gpu)
gradlogp(β) = _gradlogp_single(β, X_gpu)

ref = -v_cpu .- (transpose(X_cpu) * (X_cpu * v_cpu)) ./ Float32(Ndat)
println("analytic Hv[1:3]: ", ref[1:3])

println()
println("== reverse-on-grad (Mooncake gradient of dot ∘ gradlogp) ==")
try
    closure = β -> dot(gradlogp(β), v_gpu)
    g     = DI.gradient(closure, AutoMooncake(; config=nothing), β_gpu)
    g_vec = g isa Tuple ? first(g) : g
    Hv    = Array(g_vec)
    println("  Hv[1:3] = ", Hv[1:3])
    println("  matches: ", isapprox(Hv, ref; atol=1f-3, rtol=1f-2))
catch e
    println("  FAILED: ", first(sprint(showerror, e), 600))
end

println()
configs = [
    "AutoMooncake (DI default SecondOrder)"      => AutoMooncake(; config=nothing),
    "SecondOrder(AutoForwardDiff, AutoMooncake)" => SecondOrder(AutoForwardDiff(), AutoMooncake(; config=nothing)),
    "SecondOrder(AutoMooncake, AutoForwardDiff)" => SecondOrder(AutoMooncake(; config=nothing), AutoForwardDiff()),
]

for (lbl, b) in configs
    print("== $lbl ==\n  ")
    try
        prep = DI.prepare_hvp(logp, b, β_gpu, (v_gpu,); strict=Val(false))
        Hv   = DI.hvp(logp, prep, b, β_gpu, (v_gpu,))
        vec_ = Hv isa Tuple ? first(Hv) : Hv
        host = Array(vec_)
        ok   = isapprox(host, ref; atol=1f-3, rtol=1f-2)
        println("Hv[1:3]: ", host[1:3], "  matches: ", ok)
    catch e
        println("FAILED: ", first(sprint(showerror, e), 600))
    end
end
#=
Direct `DI.hvp(logp, AutoEnzyme(...))` fails for every config on GPU.
A forward-mode Enzyme pushforward of the user's `gradlogp` succeeds.

Same target as `dev/di_hvp_gpu_mwe.jl` / `test/test-GPU-AD-HVP.jl`:

  logp(β)     = -0.5 * (||β||^2 + ||Xβ||^2 / N)
  gradlogp(β) = -(β + Xᵀ X β / N)
  Hv         = -v - Xᵀ X v / N

Five Enzyme variants tested. Three hard-abort the Julia process during
Enzyme compilation and therefore cannot share a script with the rest:

  AutoEnzyme()                                 — hard abort during compile
  AutoEnzyme(mode=Reverse)                     — hard abort during compile
  SecondOrder(AutoEnzyme(Forward),
              AutoEnzyme(Reverse))             — hard abort during compile

To reproduce the aborts run one variant at a time. The two below throw
catchable exceptions and run cleanly in the same process.

Run from repo root:

  julia --project=test dev/di_hvp_gpu_enzyme_mwe.jl
=#

push!(LOAD_PATH, abspath(joinpath(@__DIR__, "..")))

using ParallelMCMC
using ADTypes
using DifferentiationInterface
const DI = DifferentiationInterface
import Enzyme
import CUDA
using Random

CUDA.functional() || error("requires functional CUDA device")

function _logp_single(β, X)
    Xβ = pmcmc_matmul(X, β)
    N = oftype(zero(eltype(β)), size(X, 1))
    -oftype(zero(eltype(β)), 0.5) * (sum(abs2, β) + sum(abs2, Xβ) / N)
end

function _gradlogp_single(β, X)
    Xβ = pmcmc_matmul(X, β)
    N  = oftype(zero(eltype(β)), size(X, 1))
    Y  = pmcmc_matmul(transpose(X), Xβ)
    Y  = Y ./ N
    Y  = Y .+ β
    return -Y
end

const D    = 20
const Ndat = 64

rng = MersenneTwister(20251231)
X_cpu = randn(rng, Float32, Ndat, D)
X_gpu = CUDA.CuMatrix(X_cpu)

β_cpu = randn(rng, Float32, D)
v_cpu = ones(Float32, D)
β_gpu = CUDA.CuArray(β_cpu)
v_gpu = CUDA.CuArray(v_cpu)

logp(β)     = _logp_single(β, X_gpu)
gradlogp(β) = _gradlogp_single(β, X_gpu)

ref = -v_cpu .- (transpose(X_cpu) * (X_cpu * v_cpu)) ./ Float32(Ndat)
println("analytic Hv[1:3]: ", ref[1:3])

println()
println("== forward-on-grad (Enzyme.Forward pushforward of gradlogp) ==")
try
    be    = AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const)
    prep  = DI.prepare_pushforward(gradlogp, be, β_gpu, (v_gpu,); strict=Val(false))
    Hv    = DI.pushforward(gradlogp, prep, be, β_gpu, (v_gpu,))
    vec_  = Hv isa Tuple ? first(Hv) : Hv
    host  = Array(vec_)
    println("  Hv[1:3] = ", host[1:3])
    println("  matches: ", isapprox(host, ref; atol=1f-3, rtol=1f-2))
catch e
    println("  FAILED: ", first(sprint(showerror, e), 600))
end

println()
configs = [
    "AutoEnzyme(mode=Forward)" =>
        AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const),
    "SecondOrder(AutoEnzyme(Reverse), AutoEnzyme(Forward))" =>
        SecondOrder(AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=Enzyme.Const),
                    AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const)),
]

for (lbl, b) in configs
    print("== $lbl ==\n  ")
    try
        prep = DI.prepare_hvp(logp, b, β_gpu, (v_gpu,); strict=Val(false))
        Hv   = DI.hvp(logp, prep, b, β_gpu, (v_gpu,))
        vec_ = Hv isa Tuple ? first(Hv) : Hv
        host = Array(vec_)
        ok   = isapprox(host, ref; atol=1f-3, rtol=1f-2)
        println("Hv[1:3]: ", host[1:3], "  matches: ", ok)
    catch e
        println("FAILED: ", first(sprint(showerror, e), 800))
    end
end

Totally and completely possible I am doing something fundamentally wrong--in which case please correct me

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you try the native hvp APIs of both backends to see whether the problem lies in DI or deeper?
For Mooncake, I would also try SecondOrder(AutoForwardMooncake(), AutoMooncake())

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#=
Test HVP on GPU across the backends

ENZYME
  enz-hvp-native   Enzyme.hvp(logp, x, v)                     
  enz-hvp-so       DI.hvp, SecondOrder(Enzyme.Fwd, Enzyme.Rev)  
  enz-hvp-pf       DI.pushforward(gradlogp, Enzyme.Fwd)       

  mc-hvp-default   DI.hvp(logp, AutoMooncake())              
  mc-hvp-so        DI.hvp, SecondOrder(AutoMooncakeForward, AutoMooncake)
  mc-hvp-pf        DI.pushforward(gradlogp, AutoMooncakeForward)  
=#

const BLOCK = isempty(ARGS) ? "enz-hvp-native" : ARGS[1]

using CUDA
using ADTypes
using ParallelMCMC                 
const DI = ParallelMCMC.DEER.DI 
using Random

startswith(BLOCK, "enz") && (@eval import Enzyme)
startswith(BLOCK, "mc")  && (@eval import Mooncake)

CUDA.functional() || error("requires functional CUDA device")

logp(x)     = -0.5f0 * sum(abs2, x)
gradlogp(x) = -x

rng   = MersenneTwister(20251231)
x_gpu = CUDA.CuArray(randn(rng, Float32, 8))
v_gpu = CUDA.CuArray(ones(Float32, 8))
ref   = -Array(v_gpu)                       # analytic Hv = -v

report(Hv) = (host = Array(Hv isa Tuple ? first(Hv) : Hv);
              println("  Hv[1:3] = ", host[1:3],
                      "   matches -v: ", isapprox(host, ref; atol=1f-4, rtol=1f-3)))

println("\n== $BLOCK ==")
flush(stdout); flush(stderr)
try
    if BLOCK == "enz-hvp-native"
        report(Enzyme.hvp(logp, x_gpu, v_gpu))

    elseif BLOCK == "enz-hvp-so"
        be = DI.SecondOrder(
            AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const),
            AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
                         function_annotation=Enzyme.Const),
        )
        report(DI.hvp(logp, be, x_gpu, (v_gpu,)))

    elseif BLOCK == "enz-hvp-pf"
        be = AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Forward),
                          function_annotation=Enzyme.Const)
        report(DI.pushforward(gradlogp, be, x_gpu, (v_gpu,)))

    elseif BLOCK == "mc-hvp-default"
        report(DI.hvp(logp, AutoMooncake(; config=nothing), x_gpu, (v_gpu,)))

    elseif BLOCK == "mc-hvp-so"
        if !isdefined(ADTypes, :AutoMooncakeForward)
            println("  ADTypes.AutoMooncakeForward not available in this version")
        else
            be = DI.SecondOrder(ADTypes.AutoMooncakeForward(; config=nothing),
                                AutoMooncake(; config=nothing))
            report(DI.hvp(logp, be, x_gpu, (v_gpu,)))
        end

    elseif BLOCK == "mc-hvp-pf"
        if !isdefined(ADTypes, :AutoMooncakeForward)
            println("  ADTypes.AutoMooncakeForward not available in this version")
        else
            be = ADTypes.AutoMooncakeForward(; config=nothing)
            report(DI.pushforward(gradlogp, be, x_gpu, (v_gpu,)))
        end

    else
        println("  unknown block: $BLOCK")
    end
    println("  >>> $BLOCK SUCCEEDED")
catch e
    println("  >>> $BLOCK THREW (catchable):")
    showerror(stdout, e)
    println()
end
flush(stdout); flush(stderr)

using this script---everything but the push forwards fails. So these are upstream bugs it looks like

Comment thread src/DEER/DEER.jl
Comment thread src/DEER/DEER.jl
Comment thread src/DEER/DEER.jl Outdated
Comment thread src/interface.jl
Comment thread src/ParallelMCMC.jl Outdated
Comment thread Project.toml
@rsenne rsenne requested a review from gdalle May 18, 2026 13:41
@rsenne

rsenne commented May 18, 2026

Copy link
Copy Markdown
Owner Author

hey @gdalle I addressed the two addressable point (e.g., type instability and nixing symbols) thoughts?

Also, included 2 MWRE above. If I'm not crazy--I can open issues for these on the respective repos though I'm not confident yet till someone who knows better than i says so. Also, happy to help tackle the linked DI issue for batching--it seems reasonably approachable?

Let me know what you think!

@rsenne

rsenne commented May 21, 2026

Copy link
Copy Markdown
Owner Author

Hi @wsmoses -- could I get your review on the Enzyme extension I put together here? The basic point of the extension was to get some basic LinAlg working on the GPU, but I have concerns I may have overengineered here. So, before I commit any of these changes I want to make sure what I have implemented makes sense. Thanks!

Comment thread ext/EnzymeExt.jl
Comment thread ext/EnzymeExt.jl
Comment thread ext/EnzymeExt.jl
@rsenne rsenne marked this pull request as ready for review May 24, 2026 19:18
Copilot AI review requested due to automatic review settings May 24, 2026 19:18

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR overhauls the GPU + AD backend integration by routing AD through DifferentiationInterface and making Enzyme/Mooncake/ForwardDiff optional weak dependencies (via extensions). It also adds GPU-focused Enzyme rules (via owned wrapper functions) to avoid Enzyme failures on common CUDA kernels, and expands the test/docs surface for GPU execution and AD-HVP fallbacks.

Changes:

  • Make DifferentiationInterface the unified AD entry point and require an explicit backend for ParallelMALASampler.
  • Add Enzyme extension rules for owned wrappers (pmcmc_matmul / pmcmc_dot / pmcmc_dotsum) to make GPU AD-HVP paths Enzyme-safe.
  • Add GPU AD-HVP/performance tests and new GPU documentation (limitations + worked example).

Reviewed changes

Copilot reviewed 33 out of 35 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
Project.toml Moves AD packages to [weakdeps] and wires extensions/compat for modular AD backends.
src/ParallelMCMC.jl Adds and exports owned wrapper functions for matmul/dot/reduction to support backend-specific AD rules.
src/interface.jl Makes backend a required keyword for ParallelMALASampler and updates DEER-rec builder to use strategy-dispatched AD-HVP factories.
src/DEER/DEER.jl Removes hard-coded Enzyme defaults; adds HVP strategy dispatch (forward-on-grad vs reverse-on-grad) and backend normalization hooks.
src/DEER/DEERScan.jl Minor comment/style update in scan implementation.
src/MALA/MALA.jl Tweaks internal quadratic forms and refactors JVP scalar computations.
ext/EnzymeExt.jl Adds native Enzyme rules for pmcmc_* wrappers and backend normalization for Enzyme HVP paths.
ext/DynamicPPLExt.jl Updates Turing/DynamicPPL DensityModel convenience constructor defaulting to AutoForwardDiff().
test/test-Owned-Matmul.jl New tests validating owned wrappers and Enzyme rules (forward + reverse + Const-arg branches).
test/test-GPU-AD-HVP.jl New tests for GPU AD-HVP fallback across Enzyme/Mooncake/Zygote backends.
test/test-GPU-Performance.jl New GPU-vs-CPU performance regression sanity test.
test/test-Turing-Integration.jl Expands DynamicPPL/Turing integration tests and passes explicit sampler backend.
test/test-DEER-Interface.jl Updates DEER interface tests for required sampler backend.
test/test-DEER-Turing-Logistic.jl Updates logistic regression tests for explicit backend and adjusts tolerances/params for stability.
test/test-Deer-vs-MALA.jl Updates internal AD-HVP calls to use explicit AutoEnzyme() backend.
test/test-Jacobian-Estimator.jl Updates tests to avoid removed DEFAULT_BACKEND and use explicit backend.
test/test-MALA-Kernel.jl Formatting-only refactors in MALA kernel tests.
test/test-GPU-MALA.jl Improves determinism (CUDA.seed!) and adjusts stationary mean tolerance.
test/test-Adaptive-MALA.jl Comment formatting change.
test/test-Code-Quality.jl Enables Aqua + JET checks in the test suite.
docs/src/10-getting-started.md Updates getting-started narrative and points to the new GPU execution page.
docs/src/15-gpu.md New GPU execution page covering limitations, backend choices, and a worked logistic regression example.
docs/src/assets/make_julia_deer_gif.jl Formatting-only refactor in docs asset script.
benchmarks/ParallelMCMCBenchmarks/Project.toml Adds Enzyme dependency and normalizes [sources] layout.
benchmarks/ParallelMCMCBenchmarks/src/pr_suite.jl Updates benchmark suite to pass explicit sampler backend.
benchmarks/ParallelMCMCBenchmarks/src/models/bayes_logreg.jl Comment formatting tweaks in benchmark model.
benchmarks/ParallelMCMCBenchmarks/src/models/bayes_linreg.jl Comment formatting tweaks in benchmark model.
benchmarks/ParallelMCMCBenchmarks/scripts/profile_deer_logreg_components.jl Updates profiling script to pass explicit backend in DEER rec + sampler.
benchmarks/ParallelMCMCBenchmarks/scripts/prof_view.jl Updates helper default backend to AutoEnzyme().
benchmarks/ParallelMCMCBenchmarks/scripts/new_bench.jl Updates helper default backend to AutoEnzyme().
benchmarks/ParallelMCMCBenchmarks/scripts/bench_mala_bayes.jl Updates DEER sampler construction to pass explicit backend.
benchmarks/ParallelMCMCBenchmarks/scripts/bench_deer_logreg.jl Updates GPU DEER benchmarking to pass explicit backend.
benchmarks/ParallelMCMCBenchmarks/scripts/pr_benchmarks.jl Formatting-only refactor.
benchmarks/ParallelMCMCBenchmarks/scripts/compare_pr_benchmarks.jl Formatting-only refactor.
.gitignore Ignores local debugging script directories (/scripts, /dev).
Comments suppressed due to low confidence (1)

ext/DynamicPPLExt.jl:18

  • The docstring says only DynamicPPL must be loaded, but the extension is configured to load only when DynamicPPL, ForwardDiff, and LogDensityProblems are all loaded (Project.toml). Either update the docstring to reflect the actual requirement (and mention that ForwardDiff is needed for the default AutoForwardDiff()), or loosen the extension deps and add a runtime check/error when AutoForwardDiff() is selected without ForwardDiff loaded.
Convenience constructor: wraps a DynamicPPL/Turing `@model` directly as a
`DensityModel`, automatically extracting parameter names and wiring up gradient
computation via DynamicPPL's `adtype` interface.

Requires `DynamicPPL` to be loaded.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread test/test-Owned-Matmul.jl Outdated
Comment thread test/test-GPU-Performance.jl
Comment thread docs/src/10-getting-started.md
Comment thread src/ParallelMCMC.jl
@rsenne

rsenne commented May 24, 2026

Copy link
Copy Markdown
Owner Author

Hi @penelopeysm -- could I get a quick review on some of the statements I wrote in the docs about DynamicPPL? I want to make it clear that ParallelMCMC on the GPU does not play nicely with Turing because, as far as I'm aware, DynamicPPl is not set up at the moment to handle GPU bound models. I think this is in the works, but you'd know better than I. If its incorrect in anyway please lmk! Thanks!

@penelopeysm

Copy link
Copy Markdown
Contributor

That is reasonable to me, fwiw at one point when I was still on the project I was told that GPU compatibility was not a priority

@rsenne

rsenne commented May 24, 2026

Copy link
Copy Markdown
Owner Author

Hmm okay. all the more reason to integrate FlexiChains with first-class support I guess

@rsenne

rsenne commented May 28, 2026

Copy link
Copy Markdown
Owner Author

@gdalle can I get a last quick set of eyes on this before I merge in? It seems that the Enzyme LinAlg + GPU is a real limitation. My thought is to use the current I have, and once I get an okay to merge upstream, I'll strip it out. Just have a few people pinging me to use the package so wanna get this working. Thanks!

@gdalle

gdalle commented May 28, 2026

Copy link
Copy Markdown
Collaborator

I'll review tomorrow, sorry for the delay!

Comment thread docs/src/15-gpu.md Outdated
Comment thread test/test-Owned-Matmul.jl Outdated
Comment thread test/test-Owned-Matmul.jl Outdated
Comment thread src/interface.jl Outdated
Comment thread ext/EnzymeExt.jl Outdated
Comment thread src/DEER/DEER.jl Outdated
Comment thread src/DEER/DEER.jl Outdated
Comment thread src/DEER/DEER.jl Outdated
Comment on lines +203 to +206
Default is `ForwardOnGrad()` since most DI backends support forward mode.
We override to `ReverseOnGrad()` for reverse-only backends. AutoEnzyme stays
on forward because Enzyme.Forward is robust on CuArrays once the matmul is
wrapped.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it depends whether you want to consider "backends" (and thus separate AutoMooncake from AutoMooncakeForward, AutoEnzyme(; mode=Forward) from its reverse counterpart) or "packages", in which case Enzyme and Mooncake are bidirectional

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My impulse is to design around "backends". Thoughts?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes more sense based on the existing ADTypes interface, but you might need to use DI.SecondOrder

@rsenne rsenne Jun 3, 2026

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given SecondOrder is failing I think this will be a future issue to address. So for now, I think backends will be whats considered. though, I'll edit this comment to recognize this split

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you paste the error you get with SecondOrder and an MWE?

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pasted a script above--let me refashion it into just the specific SecondOrder case and I'll paste the error

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using CUDA
using Mooncake
using ParallelMCMC
const DI = ParallelMCMC.DEER.DI
using Random

logp(x) = -0.5f0 * sum(abs2, x)

rng = MersenneTwister(20251231)
x_gpu = CUDA.CuArray(randn(rng, Float32, 8))
v_gpu = CUDA.CuArray(ones(Float32, 8))
ref = -Array(v_gpu)

be = DI.SecondOrder(DI.AutoMooncakeForward(), DI.AutoMooncake())
DI.hvp(logp, be, x_gpu, (v_gpu,))
ERROR: Mooncake.IntrinsicsWrappers.MissingIntrinsicWrapperException("Unable to translate the intrinsic Val{Core.Intrinsics.atomic_pointerref}() into a regular Julia function. Please see github.com/chalk-lab/Mooncake.jl/issues/208 for more discussion.")
Stacktrace:
  [1] translate(f::Val{Core.Intrinsics.atomic_pointerref})
    @ Mooncake.IntrinsicsWrappers /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/rules/builtins.jl:128
  [2] lift_intrinsic(::Core.IntrinsicFunction, ::Core.SSAValue, ::QuoteNode, ::Vararg{Any})
    @ Mooncake /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/ir_normalisation.jl:309
  [3] intrinsic_to_function
    @ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/ir_normalisation.jl:300 [inlined]
  [4] normalise!(ir::Compiler.IRCode, spnames::Vector{Symbol})
    @ Mooncake /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/ir_normalisation.jl:35
  [5] generate_dual_ir(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Core.MethodInstance; debug_mode::Bool, do_inline::Bool, do_optimize::Bool)
    @ Mooncake /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:196
  [6] build_frule(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Any; debug_mode::Bool, silence_debug_messages::Bool, skip_world_age_check::Bool)
    @ Mooncake /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:98
  [7] _build_rule!(rule::Mooncake.LazyFRule{…}, args::Tuple{…})
    @ Mooncake /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:532
  [8] LazyFRule
    @ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:527 [inlined]
  [9] cufunction
    @ /projectnb/ssmsvi/rsenne/.julia/packages/CUDA/bjncr/src/compiler/execution.jl:365 [inlined]
 [10] (::Tuple{…})(_2::Mooncake.Dual{…}, _3::Mooncake.Dual{…}, _4::Mooncake.Dual{…}, _5::Mooncake.Dual{…}, _6::Mooncake.Dual{…})
    @ Base.Experimental ./<missing>:0
 [11] (::MistyClosures.MistyClosure{…})(::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…})
    @ MistyClosures /projectnb/ssmsvi/rsenne/.julia/packages/MistyClosures/2vtLL/src/MistyClosures.jl:22
 [12] (::Mooncake.DerivedFRule{…})(::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…})
    @ Mooncake /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:122
 [13] __call_rule
    @ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/utils.jl:639 [inlined]
 [14] (::Mooncake.DynamicFRule{…})(::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…})
    @ Mooncake /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:582
 [15] macro expansion
    @ /projectnb/ssmsvi/rsenne/.julia/packages/CUDA/bjncr/src/compiler/execution.jl:112 [inlined]
 [16] #_#6
    @ /projectnb/ssmsvi/rsenne/.julia/packages/CUDA/bjncr/src/CUDAKernels.jl:129 [inlined]
 [17] (::Tuple{…})(_2::Mooncake.Dual{…}, _3::Mooncake.Dual{…}, _4::Mooncake.Dual{…}, _5::Mooncake.Dual{…}, _6::Mooncake.Dual)
    @ Base.Experimental ./<missing>:0
 [18] (::MistyClosures.MistyClosure{…})(::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…})
    @ MistyClosures /projectnb/ssmsvi/rsenne/.julia/packages/MistyClosures/2vtLL/src/MistyClosures.jl:22
 [19] DerivedFRule
    @ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:122 [inlined]
 [20] __call_rule
    @ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/utils.jl:639 [inlined]
 [21] _build_rule!(rule::Mooncake.LazyFRule{…}, args::Tuple{…})
    @ Mooncake /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:533
 [22] LazyFRule
    @ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:527 [inlined]
 [23] value_and_gradient
    @ /projectnb/ssmsvi/rsenne/.julia/packages/DifferentiationInterface/f85Vv/ext/DifferentiationInterfaceMooncakeExt/onearg.jl:133 [inlined]
 [24] (::Tuple{…})(_2::Mooncake.Dual{…}, _3::Mooncake.Dual{…}, _4::Mooncake.Dual{…}, _5::Mooncake.Dual{…}, _6::Mooncake.Dual{…}, _7::Mooncake.Dual{…})
    @ Base.Experimental ./<missing>:0
 [25] (::MistyClosures.MistyClosure{…})(::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…})
    @ MistyClosures /projectnb/ssmsvi/rsenne/.julia/packages/MistyClosures/2vtLL/src/MistyClosures.jl:22
 [26] DerivedFRule
    @ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:122 [inlined]
 [27] __call_rule
    @ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/utils.jl:639 [inlined]
 [28] _build_rule!(rule::Mooncake.LazyFRule{…}, args::Tuple{…})
    @ Mooncake /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:533
 [29] LazyFRule
    @ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:527 [inlined]
 [30] shuffled_gradient
    @ /projectnb/ssmsvi/rsenne/.julia/packages/DifferentiationInterface/f85Vv/src/first_order/gradient.jl:171 [inlined]
 [31] (::Tuple{…})(_2::Mooncake.Dual{…}, _3::Mooncake.Dual{…}, _4::Mooncake.Dual{…}, _5::Mooncake.Dual{…}, _6::Mooncake.Dual{…}, _7::Mooncake.Dual{…}, _8::Mooncake.Dual{…})
    @ Base.Experimental ./<missing>:0
 [32] (::MistyClosures.MistyClosure{…})(::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…})
    @ MistyClosures /projectnb/ssmsvi/rsenne/.julia/packages/MistyClosures/2vtLL/src/MistyClosures.jl:22
 [33] DerivedFRule
    @ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:122 [inlined]
 [34] __call_rule
    @ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/utils.jl:639 [inlined]
 [35] value_and_derivative!!
    @ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interface.jl:2128 [inlined]
 [36] #17
    @ /projectnb/ssmsvi/rsenne/.julia/packages/DifferentiationInterface/f85Vv/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl:37 [inlined]
 [37] map
    @ ./tuple.jl:358 [inlined]
 [38] value_and_pushforward(::typeof(DifferentiationInterface.shuffled_gradient), ::DifferentiationInterfaceMooncakeExt.MooncakeOneArgPushforwardPrep{…}, ::ADTypes.AutoMooncakeForward{…}, ::CuArray{…}, ::Tuple{…}, ::DifferentiationInterface.FunctionContext{…}, ::DifferentiationInterface.ConstantOrCache{…}, ::DifferentiationInterface.Constant{…}, ::DifferentiationInterface.Constant{…})
    @ DifferentiationInterfaceMooncakeExt /projectnb/ssmsvi/rsenne/.julia/packages/DifferentiationInterface/f85Vv/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl:36
 [39] pushforward
    @ /projectnb/ssmsvi/rsenne/.julia/packages/DifferentiationInterface/f85Vv/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl:61 [inlined]
 [40] hvp(::typeof(logp), ::DifferentiationInterface.ForwardOverAnythingHVPPrep{…}, ::DifferentiationInterface.SecondOrder{…}, ::CuArray{…}, ::Tuple{…})
    @ DifferentiationInterface /projectnb/ssmsvi/rsenne/.julia/packages/DifferentiationInterface/f85Vv/src/second_order/hvp.jl:331
 [41] hvp(::typeof(logp), ::DifferentiationInterface.SecondOrder{…}, ::CuArray{…}, ::Tuple{…})
    @ DifferentiationInterface /projectnb/ssmsvi/rsenne/.julia/packages/DifferentiationInterface/f85Vv/src/second_order/hvp.jl:75
 [42] top-level scope
    @ /projectnb/ssmsvi/rsenne/ParallelMCMC.jl/scripts/just_second_order.jl:15
Some type information was truncated. Use `show(err)` to see complete types.

let me know if anything i did was incorrect!

@rsenne rsenne Jun 3, 2026

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also--you get the same error for Mooncake when using the native API via:

cache = Mooncake.prepare_hvp_cache(logp, x_gpu)
Mooncake.value_and_hvp!!(cache, logp, v_gpu, x_gpu)

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gentle bump @gdalle

Comment thread src/DEER/DEER.jl
Comment thread src/DEER/DEER.jl
@rsenne

rsenne commented Jun 3, 2026

Copy link
Copy Markdown
Owner Author

okay--i was able to remove strict=Val(false) and prepare_pushforward when not re-used. Other changes were mostly related to incorrect/sloppiness in some of my comments which i tightened. I think we also identified some good changes/issues with future discussion that could make this package better either usptream (e.g., matmul overides in Enzyme/Batching in DI/the forward and reverse mapping you suggested) and also locally (e.g., hvp strategy logic/how we can handle backends vs.packages). Let me know if you have additional thoughts!

@rsenne

rsenne commented Jun 13, 2026

Copy link
Copy Markdown
Owner Author

hiya @gdalle, sorry to bother you again, did you get a chance to look at this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Remove default backend (Enzyme) and make DI the main UI

5 participants