GPU fixes#32
Conversation
…ault enzyme machinery.
# Conflicts: # Project.toml # ext/DynamicPPLExt.jl # ext/LogDensityProblemsExt.jl # test/test-DEER-Turing-Logistic.jl # test/test-Turing-Integration.jl
Codecov Report❌ Patch coverage is
❌ Your patch check has failed because the patch coverage (21.68%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #32 +/- ##
==========================================
- Coverage 88.65% 80.06% -8.59%
==========================================
Files 6 8 +2
Lines 1040 1164 +124
==========================================
+ Hits 922 932 +10
- Misses 118 232 +114 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
Hey @rsenne, need a review here? |
|
Hi @gdalle yes I would love that. This is my first pass on this and it would be much appreciated! |
| return DI.prepare_pushforward( | ||
| f, _hvp_forward_backend(backend), x_template, (v_template,); strict=Val(false) | ||
| ) |
There was a problem hiding this comment.
Why not use DI.hvp directly here?
There was a problem hiding this comment.
If it is because you need a batched gradient, you may be interested in JuliaDiff/DifferentiationInterface.jl#991
There was a problem hiding this comment.
Note that you can already batch a small number of tangents by passing a tuple though
There was a problem hiding this comment.
Two reasons:
- The lack of batching (happy to help tackle this!)
- Perhaps, more importantly, every second order method I've tried breaks
Here are two MRWE if interested
#=
Direct `DI.hvp(logp, ...)` fails for every Mooncake-based config on GPU.
A single Mooncake reverse pass over the user's `gradlogp` succeeds.
Target (same shape as `test/test-GPU-AD-HVP.jl`):
logp(β) = -0.5 * (||β||^2 + ||Xβ||^2 / N)
gradlogp(β) = -(β + Xᵀ X β / N)
Hv = -v - Xᵀ X v / N
Run from repo root:
julia --project=test dev/di_hvp_gpu_mwe.jl
=#
push!(LOAD_PATH, abspath(joinpath(@__DIR__, "..")))
using ParallelMCMC
using ADTypes
using DifferentiationInterface
const DI = DifferentiationInterface
import Mooncake
import ForwardDiff
import CUDA
import LinearAlgebra: dot
using Random
CUDA.functional() || error("requires functional CUDA device")
function _logp_single(β, X)
Xβ = pmcmc_matmul(X, β)
N = oftype(zero(eltype(β)), size(X, 1))
-oftype(zero(eltype(β)), 0.5) * (sum(abs2, β) + sum(abs2, Xβ) / N)
end
function _gradlogp_single(β, X)
Xβ = pmcmc_matmul(X, β)
N = oftype(zero(eltype(β)), size(X, 1))
Y = pmcmc_matmul(transpose(X), Xβ)
Y = Y ./ N
Y = Y .+ β
return -Y
end
const D = 20
const Ndat = 64
rng = MersenneTwister(20251231)
X_cpu = randn(rng, Float32, Ndat, D)
X_gpu = CUDA.CuMatrix(X_cpu)
β_cpu = randn(rng, Float32, D)
v_cpu = ones(Float32, D)
β_gpu = CUDA.CuArray(β_cpu)
v_gpu = CUDA.CuArray(v_cpu)
logp(β) = _logp_single(β, X_gpu)
gradlogp(β) = _gradlogp_single(β, X_gpu)
ref = -v_cpu .- (transpose(X_cpu) * (X_cpu * v_cpu)) ./ Float32(Ndat)
println("analytic Hv[1:3]: ", ref[1:3])
println()
println("== reverse-on-grad (Mooncake gradient of dot ∘ gradlogp) ==")
try
closure = β -> dot(gradlogp(β), v_gpu)
g = DI.gradient(closure, AutoMooncake(; config=nothing), β_gpu)
g_vec = g isa Tuple ? first(g) : g
Hv = Array(g_vec)
println(" Hv[1:3] = ", Hv[1:3])
println(" matches: ", isapprox(Hv, ref; atol=1f-3, rtol=1f-2))
catch e
println(" FAILED: ", first(sprint(showerror, e), 600))
end
println()
configs = [
"AutoMooncake (DI default SecondOrder)" => AutoMooncake(; config=nothing),
"SecondOrder(AutoForwardDiff, AutoMooncake)" => SecondOrder(AutoForwardDiff(), AutoMooncake(; config=nothing)),
"SecondOrder(AutoMooncake, AutoForwardDiff)" => SecondOrder(AutoMooncake(; config=nothing), AutoForwardDiff()),
]
for (lbl, b) in configs
print("== $lbl ==\n ")
try
prep = DI.prepare_hvp(logp, b, β_gpu, (v_gpu,); strict=Val(false))
Hv = DI.hvp(logp, prep, b, β_gpu, (v_gpu,))
vec_ = Hv isa Tuple ? first(Hv) : Hv
host = Array(vec_)
ok = isapprox(host, ref; atol=1f-3, rtol=1f-2)
println("Hv[1:3]: ", host[1:3], " matches: ", ok)
catch e
println("FAILED: ", first(sprint(showerror, e), 600))
end
end#=
Direct `DI.hvp(logp, AutoEnzyme(...))` fails for every config on GPU.
A forward-mode Enzyme pushforward of the user's `gradlogp` succeeds.
Same target as `dev/di_hvp_gpu_mwe.jl` / `test/test-GPU-AD-HVP.jl`:
logp(β) = -0.5 * (||β||^2 + ||Xβ||^2 / N)
gradlogp(β) = -(β + Xᵀ X β / N)
Hv = -v - Xᵀ X v / N
Five Enzyme variants tested. Three hard-abort the Julia process during
Enzyme compilation and therefore cannot share a script with the rest:
AutoEnzyme() — hard abort during compile
AutoEnzyme(mode=Reverse) — hard abort during compile
SecondOrder(AutoEnzyme(Forward),
AutoEnzyme(Reverse)) — hard abort during compile
To reproduce the aborts run one variant at a time. The two below throw
catchable exceptions and run cleanly in the same process.
Run from repo root:
julia --project=test dev/di_hvp_gpu_enzyme_mwe.jl
=#
push!(LOAD_PATH, abspath(joinpath(@__DIR__, "..")))
using ParallelMCMC
using ADTypes
using DifferentiationInterface
const DI = DifferentiationInterface
import Enzyme
import CUDA
using Random
CUDA.functional() || error("requires functional CUDA device")
function _logp_single(β, X)
Xβ = pmcmc_matmul(X, β)
N = oftype(zero(eltype(β)), size(X, 1))
-oftype(zero(eltype(β)), 0.5) * (sum(abs2, β) + sum(abs2, Xβ) / N)
end
function _gradlogp_single(β, X)
Xβ = pmcmc_matmul(X, β)
N = oftype(zero(eltype(β)), size(X, 1))
Y = pmcmc_matmul(transpose(X), Xβ)
Y = Y ./ N
Y = Y .+ β
return -Y
end
const D = 20
const Ndat = 64
rng = MersenneTwister(20251231)
X_cpu = randn(rng, Float32, Ndat, D)
X_gpu = CUDA.CuMatrix(X_cpu)
β_cpu = randn(rng, Float32, D)
v_cpu = ones(Float32, D)
β_gpu = CUDA.CuArray(β_cpu)
v_gpu = CUDA.CuArray(v_cpu)
logp(β) = _logp_single(β, X_gpu)
gradlogp(β) = _gradlogp_single(β, X_gpu)
ref = -v_cpu .- (transpose(X_cpu) * (X_cpu * v_cpu)) ./ Float32(Ndat)
println("analytic Hv[1:3]: ", ref[1:3])
println()
println("== forward-on-grad (Enzyme.Forward pushforward of gradlogp) ==")
try
be = AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const)
prep = DI.prepare_pushforward(gradlogp, be, β_gpu, (v_gpu,); strict=Val(false))
Hv = DI.pushforward(gradlogp, prep, be, β_gpu, (v_gpu,))
vec_ = Hv isa Tuple ? first(Hv) : Hv
host = Array(vec_)
println(" Hv[1:3] = ", host[1:3])
println(" matches: ", isapprox(host, ref; atol=1f-3, rtol=1f-2))
catch e
println(" FAILED: ", first(sprint(showerror, e), 600))
end
println()
configs = [
"AutoEnzyme(mode=Forward)" =>
AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const),
"SecondOrder(AutoEnzyme(Reverse), AutoEnzyme(Forward))" =>
SecondOrder(AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=Enzyme.Const),
AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const)),
]
for (lbl, b) in configs
print("== $lbl ==\n ")
try
prep = DI.prepare_hvp(logp, b, β_gpu, (v_gpu,); strict=Val(false))
Hv = DI.hvp(logp, prep, b, β_gpu, (v_gpu,))
vec_ = Hv isa Tuple ? first(Hv) : Hv
host = Array(vec_)
ok = isapprox(host, ref; atol=1f-3, rtol=1f-2)
println("Hv[1:3]: ", host[1:3], " matches: ", ok)
catch e
println("FAILED: ", first(sprint(showerror, e), 800))
end
endTotally and completely possible I am doing something fundamentally wrong--in which case please correct me
There was a problem hiding this comment.
Can you try the native hvp APIs of both backends to see whether the problem lies in DI or deeper?
For Mooncake, I would also try SecondOrder(AutoForwardMooncake(), AutoMooncake())
There was a problem hiding this comment.
#=
Test HVP on GPU across the backends
ENZYME
enz-hvp-native Enzyme.hvp(logp, x, v)
enz-hvp-so DI.hvp, SecondOrder(Enzyme.Fwd, Enzyme.Rev)
enz-hvp-pf DI.pushforward(gradlogp, Enzyme.Fwd)
mc-hvp-default DI.hvp(logp, AutoMooncake())
mc-hvp-so DI.hvp, SecondOrder(AutoMooncakeForward, AutoMooncake)
mc-hvp-pf DI.pushforward(gradlogp, AutoMooncakeForward)
=#
const BLOCK = isempty(ARGS) ? "enz-hvp-native" : ARGS[1]
using CUDA
using ADTypes
using ParallelMCMC
const DI = ParallelMCMC.DEER.DI
using Random
startswith(BLOCK, "enz") && (@eval import Enzyme)
startswith(BLOCK, "mc") && (@eval import Mooncake)
CUDA.functional() || error("requires functional CUDA device")
logp(x) = -0.5f0 * sum(abs2, x)
gradlogp(x) = -x
rng = MersenneTwister(20251231)
x_gpu = CUDA.CuArray(randn(rng, Float32, 8))
v_gpu = CUDA.CuArray(ones(Float32, 8))
ref = -Array(v_gpu) # analytic Hv = -v
report(Hv) = (host = Array(Hv isa Tuple ? first(Hv) : Hv);
println(" Hv[1:3] = ", host[1:3],
" matches -v: ", isapprox(host, ref; atol=1f-4, rtol=1f-3)))
println("\n== $BLOCK ==")
flush(stdout); flush(stderr)
try
if BLOCK == "enz-hvp-native"
report(Enzyme.hvp(logp, x_gpu, v_gpu))
elseif BLOCK == "enz-hvp-so"
be = DI.SecondOrder(
AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const),
AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation=Enzyme.Const),
)
report(DI.hvp(logp, be, x_gpu, (v_gpu,)))
elseif BLOCK == "enz-hvp-pf"
be = AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Forward),
function_annotation=Enzyme.Const)
report(DI.pushforward(gradlogp, be, x_gpu, (v_gpu,)))
elseif BLOCK == "mc-hvp-default"
report(DI.hvp(logp, AutoMooncake(; config=nothing), x_gpu, (v_gpu,)))
elseif BLOCK == "mc-hvp-so"
if !isdefined(ADTypes, :AutoMooncakeForward)
println(" ADTypes.AutoMooncakeForward not available in this version")
else
be = DI.SecondOrder(ADTypes.AutoMooncakeForward(; config=nothing),
AutoMooncake(; config=nothing))
report(DI.hvp(logp, be, x_gpu, (v_gpu,)))
end
elseif BLOCK == "mc-hvp-pf"
if !isdefined(ADTypes, :AutoMooncakeForward)
println(" ADTypes.AutoMooncakeForward not available in this version")
else
be = ADTypes.AutoMooncakeForward(; config=nothing)
report(DI.pushforward(gradlogp, be, x_gpu, (v_gpu,)))
end
else
println(" unknown block: $BLOCK")
end
println(" >>> $BLOCK SUCCEEDED")
catch e
println(" >>> $BLOCK THREW (catchable):")
showerror(stdout, e)
println()
end
flush(stdout); flush(stderr)
using this script---everything but the push forwards fails. So these are upstream bugs it looks like
|
hey @gdalle I addressed the two addressable point (e.g., type instability and nixing symbols) thoughts? Also, included 2 MWRE above. If I'm not crazy--I can open issues for these on the respective repos though I'm not confident yet till someone who knows better than i says so. Also, happy to help tackle the linked DI issue for batching--it seems reasonably approachable? Let me know what you think! |
|
Hi @wsmoses -- could I get your review on the Enzyme extension I put together here? The basic point of the extension was to get some basic LinAlg working on the GPU, but I have concerns I may have overengineered here. So, before I commit any of these changes I want to make sure what I have implemented makes sense. Thanks! |
There was a problem hiding this comment.
Pull request overview
This PR overhauls the GPU + AD backend integration by routing AD through DifferentiationInterface and making Enzyme/Mooncake/ForwardDiff optional weak dependencies (via extensions). It also adds GPU-focused Enzyme rules (via owned wrapper functions) to avoid Enzyme failures on common CUDA kernels, and expands the test/docs surface for GPU execution and AD-HVP fallbacks.
Changes:
- Make
DifferentiationInterfacethe unified AD entry point and require an explicitbackendforParallelMALASampler. - Add Enzyme extension rules for owned wrappers (
pmcmc_matmul/pmcmc_dot/pmcmc_dotsum) to make GPU AD-HVP paths Enzyme-safe. - Add GPU AD-HVP/performance tests and new GPU documentation (limitations + worked example).
Reviewed changes
Copilot reviewed 33 out of 35 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
Project.toml |
Moves AD packages to [weakdeps] and wires extensions/compat for modular AD backends. |
src/ParallelMCMC.jl |
Adds and exports owned wrapper functions for matmul/dot/reduction to support backend-specific AD rules. |
src/interface.jl |
Makes backend a required keyword for ParallelMALASampler and updates DEER-rec builder to use strategy-dispatched AD-HVP factories. |
src/DEER/DEER.jl |
Removes hard-coded Enzyme defaults; adds HVP strategy dispatch (forward-on-grad vs reverse-on-grad) and backend normalization hooks. |
src/DEER/DEERScan.jl |
Minor comment/style update in scan implementation. |
src/MALA/MALA.jl |
Tweaks internal quadratic forms and refactors JVP scalar computations. |
ext/EnzymeExt.jl |
Adds native Enzyme rules for pmcmc_* wrappers and backend normalization for Enzyme HVP paths. |
ext/DynamicPPLExt.jl |
Updates Turing/DynamicPPL DensityModel convenience constructor defaulting to AutoForwardDiff(). |
test/test-Owned-Matmul.jl |
New tests validating owned wrappers and Enzyme rules (forward + reverse + Const-arg branches). |
test/test-GPU-AD-HVP.jl |
New tests for GPU AD-HVP fallback across Enzyme/Mooncake/Zygote backends. |
test/test-GPU-Performance.jl |
New GPU-vs-CPU performance regression sanity test. |
test/test-Turing-Integration.jl |
Expands DynamicPPL/Turing integration tests and passes explicit sampler backend. |
test/test-DEER-Interface.jl |
Updates DEER interface tests for required sampler backend. |
test/test-DEER-Turing-Logistic.jl |
Updates logistic regression tests for explicit backend and adjusts tolerances/params for stability. |
test/test-Deer-vs-MALA.jl |
Updates internal AD-HVP calls to use explicit AutoEnzyme() backend. |
test/test-Jacobian-Estimator.jl |
Updates tests to avoid removed DEFAULT_BACKEND and use explicit backend. |
test/test-MALA-Kernel.jl |
Formatting-only refactors in MALA kernel tests. |
test/test-GPU-MALA.jl |
Improves determinism (CUDA.seed!) and adjusts stationary mean tolerance. |
test/test-Adaptive-MALA.jl |
Comment formatting change. |
test/test-Code-Quality.jl |
Enables Aqua + JET checks in the test suite. |
docs/src/10-getting-started.md |
Updates getting-started narrative and points to the new GPU execution page. |
docs/src/15-gpu.md |
New GPU execution page covering limitations, backend choices, and a worked logistic regression example. |
docs/src/assets/make_julia_deer_gif.jl |
Formatting-only refactor in docs asset script. |
benchmarks/ParallelMCMCBenchmarks/Project.toml |
Adds Enzyme dependency and normalizes [sources] layout. |
benchmarks/ParallelMCMCBenchmarks/src/pr_suite.jl |
Updates benchmark suite to pass explicit sampler backend. |
benchmarks/ParallelMCMCBenchmarks/src/models/bayes_logreg.jl |
Comment formatting tweaks in benchmark model. |
benchmarks/ParallelMCMCBenchmarks/src/models/bayes_linreg.jl |
Comment formatting tweaks in benchmark model. |
benchmarks/ParallelMCMCBenchmarks/scripts/profile_deer_logreg_components.jl |
Updates profiling script to pass explicit backend in DEER rec + sampler. |
benchmarks/ParallelMCMCBenchmarks/scripts/prof_view.jl |
Updates helper default backend to AutoEnzyme(). |
benchmarks/ParallelMCMCBenchmarks/scripts/new_bench.jl |
Updates helper default backend to AutoEnzyme(). |
benchmarks/ParallelMCMCBenchmarks/scripts/bench_mala_bayes.jl |
Updates DEER sampler construction to pass explicit backend. |
benchmarks/ParallelMCMCBenchmarks/scripts/bench_deer_logreg.jl |
Updates GPU DEER benchmarking to pass explicit backend. |
benchmarks/ParallelMCMCBenchmarks/scripts/pr_benchmarks.jl |
Formatting-only refactor. |
benchmarks/ParallelMCMCBenchmarks/scripts/compare_pr_benchmarks.jl |
Formatting-only refactor. |
.gitignore |
Ignores local debugging script directories (/scripts, /dev). |
Comments suppressed due to low confidence (1)
ext/DynamicPPLExt.jl:18
- The docstring says only
DynamicPPLmust be loaded, but the extension is configured to load only whenDynamicPPL,ForwardDiff, andLogDensityProblemsare all loaded (Project.toml). Either update the docstring to reflect the actual requirement (and mention thatForwardDiffis needed for the defaultAutoForwardDiff()), or loosen the extension deps and add a runtime check/error whenAutoForwardDiff()is selected withoutForwardDiffloaded.
Convenience constructor: wraps a DynamicPPL/Turing `@model` directly as a
`DensityModel`, automatically extracting parameter names and wiring up gradient
computation via DynamicPPL's `adtype` interface.
Requires `DynamicPPL` to be loaded.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
Hi @penelopeysm -- could I get a quick review on some of the statements I wrote in the docs about DynamicPPL? I want to make it clear that ParallelMCMC on the GPU does not play nicely with Turing because, as far as I'm aware, DynamicPPl is not set up at the moment to handle GPU bound models. I think this is in the works, but you'd know better than I. If its incorrect in anyway please lmk! Thanks! |
|
That is reasonable to me, fwiw at one point when I was still on the project I was told that GPU compatibility was not a priority |
|
Hmm okay. all the more reason to integrate FlexiChains with first-class support I guess |
|
@gdalle can I get a last quick set of eyes on this before I merge in? It seems that the Enzyme LinAlg + GPU is a real limitation. My thought is to use the current I have, and once I get an okay to merge upstream, I'll strip it out. Just have a few people pinging me to use the package so wanna get this working. Thanks! |
|
I'll review tomorrow, sorry for the delay! |
| Default is `ForwardOnGrad()` since most DI backends support forward mode. | ||
| We override to `ReverseOnGrad()` for reverse-only backends. AutoEnzyme stays | ||
| on forward because Enzyme.Forward is robust on CuArrays once the matmul is | ||
| wrapped. |
There was a problem hiding this comment.
I guess it depends whether you want to consider "backends" (and thus separate AutoMooncake from AutoMooncakeForward, AutoEnzyme(; mode=Forward) from its reverse counterpart) or "packages", in which case Enzyme and Mooncake are bidirectional
There was a problem hiding this comment.
My impulse is to design around "backends". Thoughts?
There was a problem hiding this comment.
It makes more sense based on the existing ADTypes interface, but you might need to use DI.SecondOrder
There was a problem hiding this comment.
Given SecondOrder is failing I think this will be a future issue to address. So for now, I think backends will be whats considered. though, I'll edit this comment to recognize this split
There was a problem hiding this comment.
Can you paste the error you get with SecondOrder and an MWE?
There was a problem hiding this comment.
I pasted a script above--let me refashion it into just the specific SecondOrder case and I'll paste the error
There was a problem hiding this comment.
using CUDA
using Mooncake
using ParallelMCMC
const DI = ParallelMCMC.DEER.DI
using Random
logp(x) = -0.5f0 * sum(abs2, x)
rng = MersenneTwister(20251231)
x_gpu = CUDA.CuArray(randn(rng, Float32, 8))
v_gpu = CUDA.CuArray(ones(Float32, 8))
ref = -Array(v_gpu)
be = DI.SecondOrder(DI.AutoMooncakeForward(), DI.AutoMooncake())
DI.hvp(logp, be, x_gpu, (v_gpu,))ERROR: Mooncake.IntrinsicsWrappers.MissingIntrinsicWrapperException("Unable to translate the intrinsic Val{Core.Intrinsics.atomic_pointerref}() into a regular Julia function. Please see github.com/chalk-lab/Mooncake.jl/issues/208 for more discussion.")
Stacktrace:
[1] translate(f::Val{Core.Intrinsics.atomic_pointerref})
@ Mooncake.IntrinsicsWrappers /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/rules/builtins.jl:128
[2] lift_intrinsic(::Core.IntrinsicFunction, ::Core.SSAValue, ::QuoteNode, ::Vararg{Any})
@ Mooncake /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/ir_normalisation.jl:309
[3] intrinsic_to_function
@ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/ir_normalisation.jl:300 [inlined]
[4] normalise!(ir::Compiler.IRCode, spnames::Vector{Symbol})
@ Mooncake /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/ir_normalisation.jl:35
[5] generate_dual_ir(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Core.MethodInstance; debug_mode::Bool, do_inline::Bool, do_optimize::Bool)
@ Mooncake /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:196
[6] build_frule(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Any; debug_mode::Bool, silence_debug_messages::Bool, skip_world_age_check::Bool)
@ Mooncake /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:98
[7] _build_rule!(rule::Mooncake.LazyFRule{…}, args::Tuple{…})
@ Mooncake /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:532
[8] LazyFRule
@ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:527 [inlined]
[9] cufunction
@ /projectnb/ssmsvi/rsenne/.julia/packages/CUDA/bjncr/src/compiler/execution.jl:365 [inlined]
[10] (::Tuple{…})(_2::Mooncake.Dual{…}, _3::Mooncake.Dual{…}, _4::Mooncake.Dual{…}, _5::Mooncake.Dual{…}, _6::Mooncake.Dual{…})
@ Base.Experimental ./<missing>:0
[11] (::MistyClosures.MistyClosure{…})(::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…})
@ MistyClosures /projectnb/ssmsvi/rsenne/.julia/packages/MistyClosures/2vtLL/src/MistyClosures.jl:22
[12] (::Mooncake.DerivedFRule{…})(::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…})
@ Mooncake /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:122
[13] __call_rule
@ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/utils.jl:639 [inlined]
[14] (::Mooncake.DynamicFRule{…})(::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…})
@ Mooncake /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:582
[15] macro expansion
@ /projectnb/ssmsvi/rsenne/.julia/packages/CUDA/bjncr/src/compiler/execution.jl:112 [inlined]
[16] #_#6
@ /projectnb/ssmsvi/rsenne/.julia/packages/CUDA/bjncr/src/CUDAKernels.jl:129 [inlined]
[17] (::Tuple{…})(_2::Mooncake.Dual{…}, _3::Mooncake.Dual{…}, _4::Mooncake.Dual{…}, _5::Mooncake.Dual{…}, _6::Mooncake.Dual)
@ Base.Experimental ./<missing>:0
[18] (::MistyClosures.MistyClosure{…})(::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…})
@ MistyClosures /projectnb/ssmsvi/rsenne/.julia/packages/MistyClosures/2vtLL/src/MistyClosures.jl:22
[19] DerivedFRule
@ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:122 [inlined]
[20] __call_rule
@ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/utils.jl:639 [inlined]
[21] _build_rule!(rule::Mooncake.LazyFRule{…}, args::Tuple{…})
@ Mooncake /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:533
[22] LazyFRule
@ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:527 [inlined]
[23] value_and_gradient
@ /projectnb/ssmsvi/rsenne/.julia/packages/DifferentiationInterface/f85Vv/ext/DifferentiationInterfaceMooncakeExt/onearg.jl:133 [inlined]
[24] (::Tuple{…})(_2::Mooncake.Dual{…}, _3::Mooncake.Dual{…}, _4::Mooncake.Dual{…}, _5::Mooncake.Dual{…}, _6::Mooncake.Dual{…}, _7::Mooncake.Dual{…})
@ Base.Experimental ./<missing>:0
[25] (::MistyClosures.MistyClosure{…})(::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…})
@ MistyClosures /projectnb/ssmsvi/rsenne/.julia/packages/MistyClosures/2vtLL/src/MistyClosures.jl:22
[26] DerivedFRule
@ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:122 [inlined]
[27] __call_rule
@ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/utils.jl:639 [inlined]
[28] _build_rule!(rule::Mooncake.LazyFRule{…}, args::Tuple{…})
@ Mooncake /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:533
[29] LazyFRule
@ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:527 [inlined]
[30] shuffled_gradient
@ /projectnb/ssmsvi/rsenne/.julia/packages/DifferentiationInterface/f85Vv/src/first_order/gradient.jl:171 [inlined]
[31] (::Tuple{…})(_2::Mooncake.Dual{…}, _3::Mooncake.Dual{…}, _4::Mooncake.Dual{…}, _5::Mooncake.Dual{…}, _6::Mooncake.Dual{…}, _7::Mooncake.Dual{…}, _8::Mooncake.Dual{…})
@ Base.Experimental ./<missing>:0
[32] (::MistyClosures.MistyClosure{…})(::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…}, ::Mooncake.Dual{…})
@ MistyClosures /projectnb/ssmsvi/rsenne/.julia/packages/MistyClosures/2vtLL/src/MistyClosures.jl:22
[33] DerivedFRule
@ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interpreter/forward_mode.jl:122 [inlined]
[34] __call_rule
@ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/utils.jl:639 [inlined]
[35] value_and_derivative!!
@ /projectnb/ssmsvi/rsenne/.julia/packages/Mooncake/O5brV/src/interface.jl:2128 [inlined]
[36] #17
@ /projectnb/ssmsvi/rsenne/.julia/packages/DifferentiationInterface/f85Vv/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl:37 [inlined]
[37] map
@ ./tuple.jl:358 [inlined]
[38] value_and_pushforward(::typeof(DifferentiationInterface.shuffled_gradient), ::DifferentiationInterfaceMooncakeExt.MooncakeOneArgPushforwardPrep{…}, ::ADTypes.AutoMooncakeForward{…}, ::CuArray{…}, ::Tuple{…}, ::DifferentiationInterface.FunctionContext{…}, ::DifferentiationInterface.ConstantOrCache{…}, ::DifferentiationInterface.Constant{…}, ::DifferentiationInterface.Constant{…})
@ DifferentiationInterfaceMooncakeExt /projectnb/ssmsvi/rsenne/.julia/packages/DifferentiationInterface/f85Vv/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl:36
[39] pushforward
@ /projectnb/ssmsvi/rsenne/.julia/packages/DifferentiationInterface/f85Vv/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl:61 [inlined]
[40] hvp(::typeof(logp), ::DifferentiationInterface.ForwardOverAnythingHVPPrep{…}, ::DifferentiationInterface.SecondOrder{…}, ::CuArray{…}, ::Tuple{…})
@ DifferentiationInterface /projectnb/ssmsvi/rsenne/.julia/packages/DifferentiationInterface/f85Vv/src/second_order/hvp.jl:331
[41] hvp(::typeof(logp), ::DifferentiationInterface.SecondOrder{…}, ::CuArray{…}, ::Tuple{…})
@ DifferentiationInterface /projectnb/ssmsvi/rsenne/.julia/packages/DifferentiationInterface/f85Vv/src/second_order/hvp.jl:75
[42] top-level scope
@ /projectnb/ssmsvi/rsenne/ParallelMCMC.jl/scripts/just_second_order.jl:15
Some type information was truncated. Use `show(err)` to see complete types.let me know if anything i did was incorrect!
There was a problem hiding this comment.
Also--you get the same error for Mooncake when using the native API via:
cache = Mooncake.prepare_hvp_cache(logp, x_gpu)
Mooncake.value_and_hvp!!(cache, logp, v_gpu, x_gpu)|
okay--i was able to remove |
|
hiya @gdalle, sorry to bother you again, did you get a chance to look at this? |
This branch lands the GPU + AD-backend overhaul on top of main. The big themes:
Modular AD via DifferentiationInterface
EnzymeExt: GPU-safe Enzyme rules
Tests
Needed Changes Prior to Merging
Resolves #29 and provides a workaround to #25