diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 31bfab6b2..ec217248d 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -2,7 +2,7 @@ module LinearSolveForwardDiffExt using LinearSolve using LinearSolve: SciMLLinearSolveAlgorithm, __init, LinearVerbosity, DefaultLinearSolver, - DefaultAlgorithmChoice, defaultalg + DefaultAlgorithmChoice, defaultalg, reinit! using LinearAlgebra using ForwardDiff using ForwardDiff: Dual, Partials @@ -347,6 +347,78 @@ function setu!(dc::DualLinearCache, u) partial_vals!(getfield(dc, :partials_u), u) # Update in-place end +function SciMLBase.reinit!(cache::DualLinearCache; + A = nothing, + b = nothing, + u = nothing, + p = nothing, + reuse_precs = false) + linear_cache = getfield(cache, :linear_cache) + + # Compute freshness flags like in LinearCache reinit! + isfresh = !isnothing(A) + precsisfresh = !reuse_precs && (isfresh || !isnothing(p)) + isfresh |= linear_cache.isfresh + precsisfresh |= linear_cache.precsisfresh + + # Update A if provided + if !isnothing(A) + # Update the primal value in linear_cache + nodual_value!(linear_cache.A, A) + # Update dual_A + setfield!(cache, :dual_A, A) + # Update partials_A + partial_vals!(getfield(cache, :partials_A), A) + # Update partials_A_list from new partials + partials_A = getfield(cache, :partials_A) + partials_A_list = getfield(cache, :partials_A_list) + if !isnothing(partials_A) && !isnothing(partials_A_list) + update_partials_list!(partials_A, partials_A_list) + end + # Invalidate RHS cache + setfield!(cache, :rhs_cache_valid, false) + end + + # Update b if provided + if !isnothing(b) + # Update the primal value in linear_cache + nodual_value!(linear_cache.b, b) + # Update dual_b + setfield!(cache, :dual_b, b) + # Update partials_b + partial_vals!(getfield(cache, :partials_b), b) + # Update partials_b_list from new partials + partials_b = getfield(cache, :partials_b) + partials_b_list = getfield(cache, :partials_b_list) + if !isnothing(partials_b) && !isnothing(partials_b_list) + update_partials_list!(partials_b, partials_b_list) + end + # Invalidate RHS cache + setfield!(cache, :rhs_cache_valid, false) + end + + # Update u if provided + if !isnothing(u) + # Update the primal value in linear_cache + nodual_value!(linear_cache.u, u) + # Update dual_u + setfield!(cache, :dual_u, u) + # Update partials_u + partial_vals!(getfield(cache, :partials_u), u) + end + + # Update p if provided + if !isnothing(p) + linear_cache.p = p + end + + # Set freshness flags on linear_cache + linear_cache.isfresh = true + linear_cache.precsisfresh = precsisfresh + + nothing +end + function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) # If the property is A or b, also update it in the LinearCache if sym === :A @@ -395,7 +467,9 @@ partial_vals!(out, x) = map!(partial_vals, out, x) # Update in-place nodual_value(x) = x nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.value(x) nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inner dual intact -nodual_value(x::AbstractArray{<:Dual}) = nodual_value!(similar(x, typeof(nodual_value(first(x)))), x) +function nodual_value(x::AbstractArray{<:Dual}) + nodual_value!(similar(x, typeof(nodual_value(first(x)))), x) +end nodual_value!(out, x) = map!(nodual_value, out, x) # Update in-place function update_partials_list!(partial_matrix::AbstractVector{T}, list_cache) where {T} diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index 4dd936873..598496b6e 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -189,6 +189,34 @@ backslash_x_p = A \ b @test ≈(overload_x_p, backslash_x_p, rtol = 1e-9) +# Test reinit! for DualLinearCache +A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) +prob = LinearProblem(sparse(A), sparse(b)) +cache = init(prob, UMFPACKFactorization()) +overload_x_p = solve!(cache) +backslash_x_p = A \ b +@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9) + +# Now use reinit! to update A +new_A, new_b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)]) +reinit!(cache; A = sparse(new_A)) +overload_x_p = solve!(cache, UMFPACKFactorization()) +backslash_x_p = new_A \ b +@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9) + +# Test reinit! with both A and b +reinit!(cache; A = sparse(new_A), b = sparse(new_b)) +overload_x_p = solve!(cache, UMFPACKFactorization()) +backslash_x_p = new_A \ new_b +@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9) + +# Test reinit! with just b +A2, b2 = h([ForwardDiff.Dual(7.0, 1.0, 0.0), ForwardDiff.Dual(7.0, 0.0, 1.0)]) +reinit!(cache; b = sparse(b2)) +overload_x_p = solve!(cache, UMFPACKFactorization()) +backslash_x_p = new_A \ b2 +@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9) + # Test that GenericLU doesn't create a DualLinearCache A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])