Review copy: ForwardDiff support for non-square systems in LinearSolve.

What the hunks below do:
  * ext/LinearSolveForwardDiffExt.jl
      - linearsolve_forwarddiff_solve!: adds a `!issquare(A)` branch that solves the
        primal problem, builds one right-hand side per partial from the differentiated
        normal equations, solves for each partial through the same linear cache, and
        then restores the cache's primal `b` and `u`.
      - __dual_init: takes the shape of `rhs_list` and of the initial dual `u` from the
        solution vector whenever length(u) != length(b), and allocates the two u-caches
        with `similar(non_partial_cache.u)`.
  * test/forwarddiff_overloads.jl
      - compares `solve` / `solve!` with `\` on a 2×1 and on a random 10×3 dual-valued
        system.

Review notes:
  * The partials assertion that follows the cached `solve!` checks `sol_cache_overdet`
    (it previously re-checked `sol_overdet`); the unused `p = [2.0, 3.0]` was dropped.
  * The normal-equation derivation is stated for overdetermined systems, yet
    `!issquare(A)` is also true for wide matrices - TODO confirm that case.
  * Line breaks and indentation were reconstructed from a whitespace-collapsed copy of
    this patch; regenerate it from the repository before applying.

diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl
index 31bfab6b2..2d0d1cbc3 100644
--- a/ext/LinearSolveForwardDiffExt.jl
+++ b/ext/LinearSolveForwardDiffExt.jl
@@ -63,6 +63,78 @@ LinearSolve.@concrete mutable struct DualLinearCache{DT}
 end
 
 function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kwargs...)
+    # Check if A is square - if not, use the non-square system path
+    A = cache.linear_cache.A
+
+    if !issquare(A)
+        # For overdetermined systems, differentiate the normal equations A'Ax = A'b w.r.t. θ
+        # (NOTE(review): wide A also enters this branch - confirm that case is intended):
+        #   dA'/dθ · Ax + A' · dA/dθ · x + A'A · dx/dθ = dA'/dθ · b + A' · db/dθ
+        # Rearranging gives the system solved for every partial:
+        #   A'A · dx/dθ = A' · db/dθ + dA'/dθ · (b - Ax) - A' · dA/dθ · x
+
+        # Solve the primal problem first
+        cache.dual_u0_cache .= cache.linear_cache.u
+        sol = solve!(cache.linear_cache, alg, args...; kwargs...)
+        cache.primal_u_cache .= cache.linear_cache.u
+        cache.primal_b_cache .= cache.linear_cache.b
+        u = sol.u
+
+        # Get the partials and primal values
+        # After solve!, cache.linear_cache.A may be modified by factorization,
+        # so we extract primal A from the original dual_A stored in cache
+        ∂_A = cache.partials_A
+        ∂_b = cache.partials_b
+        A = nodual_value(cache.dual_A)
+        A_adj = A'
+        b = cache.primal_b_cache
+        residual = b - A * u  # residual r = b - Ax
+
+        rhs_list = cache.rhs_list
+
+        # Update cached partials lists if cache is invalid
+        if !cache.rhs_cache_valid
+            if !isnothing(∂_A)
+                update_partials_list!(∂_A, cache.partials_A_list)
+            end
+            if !isnothing(∂_b)
+                update_partials_list!(∂_b, cache.partials_b_list)
+            end
+            cache.rhs_cache_valid = true
+        end
+
+        A_list = cache.partials_A_list
+        b_list = cache.partials_b_list
+
+        # Compute RHS: A' · db/dθ + dA'/dθ · (b - Ax) - A' · dA/dθ · x
+        for i in eachindex(rhs_list)
+            if !isnothing(b_list)
+                # A' · db/dθ
+                rhs_list[i] .= A_adj * b_list[i]
+            else
+                fill!(rhs_list[i], 0)
+            end
+
+            if !isnothing(A_list)
+                # Add dA'/dθ · (b - Ax) = (dA/dθ)' · residual
+                rhs_list[i] .+= A_list[i]' * residual
+                # Subtract A' · dA/dθ · x
+                temp = A_list[i] * u
+                rhs_list[i] .-= A_adj * temp
+            end
+        end
+        # Feed A' \ rhs to the cached solver; each rhs_list[i] is overwritten with that solve
+        for i in eachindex(rhs_list)
+            cache.linear_cache.b .= A_adj \ rhs_list[i]
+            rhs_list[i] .= solve!(cache.linear_cache, alg, args...; kwargs...).u
+        end
+
+        cache.linear_cache.b .= cache.primal_b_cache
+        cache.linear_cache.u .= cache.primal_u_cache
+
+        return sol
+    end
+
     # Solve the primal problem
     cache.dual_u0_cache .= cache.linear_cache.u
     sol = solve!(cache.linear_cache, alg, args...; kwargs...)
@@ -71,14 +143,13 @@ function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kw
     cache.primal_b_cache .= cache.linear_cache.b
     uu = sol.u
 
-    # Solves Dual partials separately 
+    # Solves Dual partials separately
     ∂_A = cache.partials_A
     ∂_b = cache.partials_b
 
     xp_linsolve_rhs!(uu, ∂_A, ∂_b, cache)
 
     rhs_list = cache.rhs_list
-
     cache.linear_cache.u .= cache.dual_u0_cache
     # We can reuse the linear cache, because the same factorization will work for the partials.
     for i in eachindex(rhs_list)
@@ -177,7 +248,6 @@ function linearsolve_dual_solution!(dual_u::AbstractArray{DT}, u::AbstractArray,
         partials) where {T, V, N, DT <: Dual{T, V, N}}
     # Direct in-place construction of dual numbers without temporary allocations
     n_partials = length(partials)
-
     for i in eachindex(u, dual_u)
         # Extract partials for this element directly
         partial_vals = ntuple(Val(N)) do j
@@ -263,15 +333,26 @@ function __dual_init(
     partials_b_list = !isnothing(∂_b) ? partials_to_list(∂_b) : nothing
 
     # Determine size and type for rhs_list
+    # For square systems, use b size. For overdetermined, use u size (solution size)
+    rhs_template = length(non_partial_cache.u) == length(non_partial_cache.b) ?
+                   non_partial_cache.b : non_partial_cache.u
+
     if !isnothing(partials_A_list)
         n_partials = length(partials_A_list)
-        rhs_list = [similar(non_partial_cache.b) for _ in 1:n_partials]
+        rhs_list = [similar(rhs_template) for _ in 1:n_partials]
     elseif !isnothing(partials_b_list)
         n_partials = length(partials_b_list)
-        rhs_list = [similar(non_partial_cache.b) for _ in 1:n_partials]
+        rhs_list = [similar(rhs_template) for _ in 1:n_partials]
     else
         rhs_list = nothing
     end
+    # Use b for restructuring if sizes match (square system), otherwise use u (non-square)
+    # This preserves ComponentArray structure from b when possible
+    dual_u_init = if length(non_partial_cache.u) == length(b)
+        ArrayInterface.restructure(b, zeros(dual_type, length(b)))
+    else
+        ArrayInterface.restructure(non_partial_cache.u, zeros(dual_type, length(non_partial_cache.u)))
+    end
 
     return DualLinearCache{dual_type}(
         non_partial_cache,
@@ -281,13 +362,13 @@ function __dual_init(
         partials_A_list,
         partials_b_list,
         rhs_list,
-        similar(new_b),
-        similar(new_b),
-        similar(new_b),
+        similar(non_partial_cache.u), # Use u's size, not b's size
+        similar(non_partial_cache.u), # primal_u_cache
+        similar(new_b), # primal_b_cache
         true, # Cache is initially valid
         A,
         b,
-        ArrayInterface.restructure(b, zeros(dual_type, length(b)))
+        dual_u_init
     )
 end
 
@@ -300,6 +381,8 @@ function SciMLBase.solve!(
         ForwardDiff.Dual}
     primal_sol = linearsolve_forwarddiff_solve!(
         cache::DualLinearCache, getfield(cache, :linear_cache).alg, args...; kwargs...)
+
+    # Construct dual solution from primal solution and partials
     dual_sol = linearsolve_dual_solution(getfield(cache, :linear_cache).u,
         getfield(cache, :rhs_list), cache)
     # For scalars, we still need to assign since cache.dual_u might not be pre-allocated
diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl
index 4dd936873..3a81d98a8 100644
--- a/test/forwarddiff_overloads.jl
+++ b/test/forwarddiff_overloads.jl
@@ -241,3 +241,53 @@ grad = ForwardDiff.gradient(component_linsolve, p_test)
 @test length(grad) == 2
 @test !any(isnan, grad)
 @test !any(isinf, grad)
+
+# Test overdetermined (non-square) system: 2×1 matrix with dual numbers
+# This tests that cache sizes are correctly allocated when solution size != RHS size
+A_overdet = reshape([ForwardDiff.Dual(2.0, 1.0), ForwardDiff.Dual(3.0, 1.0)], 2, 1) # 2×1 matrix
+b_overdet = [ForwardDiff.Dual(5.0, 1.0), ForwardDiff.Dual(8.0, 9.0)]
+
+prob_overdet = LinearProblem(A_overdet, b_overdet)
+sol_overdet = solve(prob_overdet)
+backslash_overdet = A_overdet \ b_overdet
+
+# Test that solution has correct dimensions (length 1, not length 2)
+@test length(sol_overdet.u) == 1
+
+# Primal values should match
+@test ForwardDiff.value.(sol_overdet.u) ≈ ForwardDiff.value.(backslash_overdet)
+
+# Dual values should match
+@test ForwardDiff.partials.(sol_overdet.u) ≈ ForwardDiff.partials.(backslash_overdet)
+
+# Test with cache - should give identical results
+cache_overdet = init(prob_overdet)
+sol_cache_overdet = solve!(cache_overdet)
+@test sol_cache_overdet.u ≈ sol_overdet.u
+
+# Dual values of the cached solve should match as well
+@test ForwardDiff.partials.(sol_cache_overdet.u) ≈ ForwardDiff.partials.(backslash_overdet)
+
+# Test larger overdetermined system with dual numbers
+m, n = 10, 3
+A_large = rand(m, n)
+# Two partials: the first seeds row 1 of A and b[1], the second column 1 of A and b[2]
+A_large_dual = [ForwardDiff.Dual(A_large[i, j], i == 1 ? 1.0 : 0.0, j == 1 ? 1.0 : 0.0)
+                for i in 1:m, j in 1:n]
+b_large_dual = [ForwardDiff.Dual(rand(), i == 1 ? 1.0 : 0.0, i == 2 ? 1.0 : 0.0)
+                for i in 1:m]
+
+prob_large = LinearProblem(A_large_dual, b_large_dual)
+sol_large = solve(prob_large)
+backslash_large = A_large_dual \ b_large_dual
+
+# Test primal values match
+@test ForwardDiff.value.(sol_large.u) ≈ ForwardDiff.value.(backslash_large)
+
+# Both solutions should satisfy the normal equations A'Ax = A'b
+@test A_large_dual' * A_large_dual * sol_large.u ≈ A_large_dual' * b_large_dual
+@test A_large_dual' * A_large_dual * backslash_large ≈ A_large_dual' * b_large_dual
+
+# Test partials match
+@test ForwardDiff.partials.(sol_large.u) ≈ ForwardDiff.partials.(backslash_large)
+