@@ -2,7 +2,7 @@ module LinearSolveForwardDiffExt
22
33using LinearSolve
44using LinearSolve: SciMLLinearSolveAlgorithm, __init, LinearVerbosity, DefaultLinearSolver,
5- DefaultAlgorithmChoice, defaultalg
5+ DefaultAlgorithmChoice, defaultalg, reinit!
66using LinearAlgebra
77using ForwardDiff
88using ForwardDiff: Dual, Partials
@@ -430,6 +430,38 @@ function setu!(dc::DualLinearCache, u)
430430 partial_vals! (getfield (dc, :partials_u ), u) # Update in-place
431431end
432432
433+ function SciMLBase. reinit! (cache:: DualLinearCache ;
434+ A = nothing ,
435+ b = nothing ,
436+ u = nothing ,
437+ p = nothing ,
438+ reuse_precs = false )
439+ if ! isnothing (A)
440+ setA! (cache, A)
441+ end
442+
443+ if ! isnothing (b)
444+ setb! (cache, b)
445+ end
446+
447+ if ! isnothing (u)
448+ setu! (cache, u)
449+ end
450+
451+ if ! isnothing (p)
452+ cache. linear_cache. p= p
453+ end
454+
455+ isfresh = ! isnothing (A)
456+ precsisfresh = ! reuse_precs && (isfresh || ! isnothing (p))
457+ isfresh |= cache. linear_cache. isfresh
458+ precsisfresh |= cache. linear_cache. precsisfresh
459+ cache. linear_cache. isfresh = true
460+ cache. linear_cache. precsisfresh = precsisfresh
461+
462+ nothing
463+ end
464+
433465function Base. setproperty! (dc:: DualLinearCache , sym:: Symbol , val)
434466 # If the property is A or b, also update it in the LinearCache
435467 if sym === :A
@@ -478,7 +510,9 @@ partial_vals!(out, x) = map!(partial_vals, out, x) # Update in-place
478510nodual_value (x) = x
479511nodual_value (x:: Dual{T, V, P} ) where {T, V <: AbstractFloat , P} = ForwardDiff. value (x)
480512nodual_value (x:: Dual{T, V, P} ) where {T, V <: Dual , P} = x. value # Keep the inner dual intact
481- nodual_value (x:: AbstractArray{<:Dual} ) = nodual_value! (similar (x, typeof (nodual_value (first (x)))), x)
513+ function nodual_value (x:: AbstractArray{<:Dual} )
514+ nodual_value! (similar (x, typeof (nodual_value (first (x)))), x)
515+ end
482516nodual_value! (out, x) = map! (nodual_value, out, x) # Update in-place
483517
484518function update_partials_list! (partial_matrix:: AbstractVector{T} , list_cache) where {T}
0 commit comments