Skip to content

Commit dac8f53

Browse files
Merge pull request #836 from j-fu/dualreinit
Attempt the implementation of missing reinit! for dual cache
2 parents a5f623e + b1ffebf commit dac8f53

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module LinearSolveForwardDiffExt
22

33
using LinearSolve
44
using LinearSolve: SciMLLinearSolveAlgorithm, __init, LinearVerbosity, DefaultLinearSolver,
5-
DefaultAlgorithmChoice, defaultalg
5+
DefaultAlgorithmChoice, defaultalg, reinit!
66
using LinearAlgebra
77
using ForwardDiff
88
using ForwardDiff: Dual, Partials
@@ -430,6 +430,38 @@ function setu!(dc::DualLinearCache, u)
430430
partial_vals!(getfield(dc, :partials_u), u) # Update in-place
431431
end
432432

433+
function SciMLBase.reinit!(cache::DualLinearCache;
434+
A = nothing,
435+
b = nothing,
436+
u = nothing,
437+
p = nothing,
438+
reuse_precs = false)
439+
if !isnothing(A)
440+
setA!(cache, A)
441+
end
442+
443+
if !isnothing(b)
444+
setb!(cache, b)
445+
end
446+
447+
if !isnothing(u)
448+
setu!(cache, u)
449+
end
450+
451+
if !isnothing(p)
452+
cache.linear_cache.p=p
453+
end
454+
455+
isfresh = !isnothing(A)
456+
precsisfresh = !reuse_precs && (isfresh || !isnothing(p))
457+
isfresh |= cache.linear_cache.isfresh
458+
precsisfresh |= cache.linear_cache.precsisfresh
459+
cache.linear_cache.isfresh = true
460+
cache.linear_cache.precsisfresh = precsisfresh
461+
462+
nothing
463+
end
464+
433465
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
434466
# If the property is A or b, also update it in the LinearCache
435467
if sym === :A
@@ -478,7 +510,9 @@ partial_vals!(out, x) = map!(partial_vals, out, x) # Update in-place
478510
nodual_value(x) = x
479511
nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.value(x)
480512
nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inner dual intact
481-
nodual_value(x::AbstractArray{<:Dual}) = nodual_value!(similar(x, typeof(nodual_value(first(x)))), x)
513+
function nodual_value(x::AbstractArray{<:Dual})
514+
nodual_value!(similar(x, typeof(nodual_value(first(x)))), x)
515+
end
482516
nodual_value!(out, x) = map!(nodual_value, out, x) # Update in-place
483517

484518
function update_partials_list!(partial_matrix::AbstractVector{T}, list_cache) where {T}

test/forwarddiff_overloads.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,13 @@ backslash_x_p = A \ b
189189

190190
@test (overload_x_p, backslash_x_p, rtol = 1e-9)
191191

192+
A[1, 1]+=2
193+
cache = overload_x_p.cache
194+
reinit!(cache; A = sparse(A))
195+
overload_x_p = solve!(cache, UMFPACKFactorization())
196+
backslash_x_p = A \ b
197+
@test (overload_x_p, backslash_x_p, rtol = 1e-9)
198+
192199
# Test that GenericLU doesn't create a DualLinearCache
193200
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
194201

0 commit comments

Comments
 (0)