Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 76 additions & 2 deletions ext/LinearSolveForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module LinearSolveForwardDiffExt

using LinearSolve
using LinearSolve: SciMLLinearSolveAlgorithm, __init, LinearVerbosity, DefaultLinearSolver,
DefaultAlgorithmChoice, defaultalg
DefaultAlgorithmChoice, defaultalg, reinit!
using LinearAlgebra
using ForwardDiff
using ForwardDiff: Dual, Partials
Expand Down Expand Up @@ -347,6 +347,78 @@ function setu!(dc::DualLinearCache, u)
partial_vals!(getfield(dc, :partials_u), u) # Update in-place
end

function SciMLBase.reinit!(cache::DualLinearCache;
A = nothing,
b = nothing,
u = nothing,
p = nothing,
reuse_precs = false)
linear_cache = getfield(cache, :linear_cache)

# Compute freshness flags like in LinearCache reinit!
isfresh = !isnothing(A)
precsisfresh = !reuse_precs && (isfresh || !isnothing(p))
isfresh |= linear_cache.isfresh
precsisfresh |= linear_cache.precsisfresh

# Update A if provided
if !isnothing(A)
# Update the primal value in linear_cache
nodual_value!(linear_cache.A, A)
# Update dual_A
setfield!(cache, :dual_A, A)
# Update partials_A
partial_vals!(getfield(cache, :partials_A), A)
# Update partials_A_list from new partials
partials_A = getfield(cache, :partials_A)
partials_A_list = getfield(cache, :partials_A_list)
if !isnothing(partials_A) && !isnothing(partials_A_list)
update_partials_list!(partials_A, partials_A_list)
end
# Invalidate RHS cache
setfield!(cache, :rhs_cache_valid, false)
end

# Update b if provided
if !isnothing(b)
# Update the primal value in linear_cache
nodual_value!(linear_cache.b, b)
# Update dual_b
setfield!(cache, :dual_b, b)
# Update partials_b
partial_vals!(getfield(cache, :partials_b), b)
# Update partials_b_list from new partials
partials_b = getfield(cache, :partials_b)
partials_b_list = getfield(cache, :partials_b_list)
if !isnothing(partials_b) && !isnothing(partials_b_list)
update_partials_list!(partials_b, partials_b_list)
end
# Invalidate RHS cache
setfield!(cache, :rhs_cache_valid, false)
end

# Update u if provided
if !isnothing(u)
# Update the primal value in linear_cache
nodual_value!(linear_cache.u, u)
# Update dual_u
setfield!(cache, :dual_u, u)
# Update partials_u
partial_vals!(getfield(cache, :partials_u), u)
end

# Update p if provided
if !isnothing(p)
linear_cache.p = p
end

# Set freshness flags on linear_cache
linear_cache.isfresh = true
linear_cache.precsisfresh = precsisfresh

nothing
end

function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
# If the property is A or b, also update it in the LinearCache
if sym === :A
Expand Down Expand Up @@ -395,7 +467,9 @@ partial_vals!(out, x) = map!(partial_vals, out, x) # Update in-place
nodual_value(x) = x
nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.value(x)
nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inner dual intact
nodual_value(x::AbstractArray{<:Dual}) = nodual_value!(similar(x, typeof(nodual_value(first(x)))), x)
function nodual_value(x::AbstractArray{<:Dual})
nodual_value!(similar(x, typeof(nodual_value(first(x)))), x)
end
nodual_value!(out, x) = map!(nodual_value, out, x) # Update in-place

function update_partials_list!(partial_matrix::AbstractVector{T}, list_cache) where {T}
Expand Down
28 changes: 28 additions & 0 deletions test/forwarddiff_overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,34 @@ backslash_x_p = A \ b

@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9)

# Test reinit! for DualLinearCache
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
prob = LinearProblem(sparse(A), sparse(b))
cache = init(prob, UMFPACKFactorization())
overload_x_p = solve!(cache)
backslash_x_p = A \ b
@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9)

# Now use reinit! to update A
new_A, new_b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)])
reinit!(cache; A = sparse(new_A))
overload_x_p = solve!(cache, UMFPACKFactorization())
backslash_x_p = new_A \ b
@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9)

# Test reinit! with both A and b
reinit!(cache; A = sparse(new_A), b = sparse(new_b))
overload_x_p = solve!(cache, UMFPACKFactorization())
backslash_x_p = new_A \ new_b
@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9)

# Test reinit! with just b
A2, b2 = h([ForwardDiff.Dual(7.0, 1.0, 0.0), ForwardDiff.Dual(7.0, 0.0, 1.0)])
reinit!(cache; b = sparse(b2))
overload_x_p = solve!(cache, UMFPACKFactorization())
backslash_x_p = new_A \ b2
@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9)

# Test that GenericLU doesn't create a DualLinearCache
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])

Expand Down
Loading