@@ -63,6 +63,78 @@ LinearSolve.@concrete mutable struct DualLinearCache{DT}
6363end
6464
6565function linearsolve_forwarddiff_solve! (cache:: DualLinearCache , alg, args... ; kwargs... )
66+ # Check if A is square - if not, use the non-square system path
67+ A = cache. linear_cache. A
68+
69+ if ! issquare (A)
70+ # For overdetermined systems, differentiate the normal equations: A'Ax = A'b
71+ # Taking d/dθ of both sides:
72+ # dA'/dθ · Ax + A' · dA/dθ · x + A'A · dx/dθ = dA'/dθ · b + A' · db/dθ
73+ # Rearranging:
74+ # A'A · dx/dθ = A' · db/dθ + dA'/dθ · (b - Ax) - A' · dA/dθ · x
75+
76+ # Solve the primal problem first
77+ cache. dual_u0_cache .= cache. linear_cache. u
78+ sol = solve! (cache. linear_cache, alg, args... ; kwargs... )
79+ cache. primal_u_cache .= cache. linear_cache. u
80+ cache. primal_b_cache .= cache. linear_cache. b
81+ u = sol. u
82+
83+ # Get the partials and primal values
84+ # After solve!, cache.linear_cache.A may be modified by factorization,
85+ # so we extract primal A from the original dual_A stored in cache
86+ ∂_A = cache. partials_A
87+ ∂_b = cache. partials_b
88+ A = nodual_value (cache. dual_A)
89+ A_adj = A'
90+ b = cache. primal_b_cache
91+ residual = b - A * u # residual r = b - Ax
92+
93+ rhs_list = cache. rhs_list
94+
95+ # Update cached partials lists if cache is invalid
96+ if ! cache. rhs_cache_valid
97+ if ! isnothing (∂_A)
98+ update_partials_list! (∂_A, cache. partials_A_list)
99+ end
100+ if ! isnothing (∂_b)
101+ update_partials_list! (∂_b, cache. partials_b_list)
102+ end
103+ cache. rhs_cache_valid = true
104+ end
105+
106+ A_list = cache. partials_A_list
107+ b_list = cache. partials_b_list
108+
109+ # Compute RHS: A' · db/dθ + dA'/dθ · (b - Ax) - A' · dA/dθ · x
110+ for i in eachindex (rhs_list)
111+ if ! isnothing (b_list)
112+ # A' · db/dθ
113+ rhs_list[i] .= A_adj * b_list[i]
114+ else
115+ fill! (rhs_list[i], 0 )
116+ end
117+
118+ if ! isnothing (A_list)
119+ # Add dA'/dθ · (b - Ax) = (dA/dθ)' · residual
120+ rhs_list[i] .+ = A_list[i]' * residual
121+ # Subtract A' · dA/dθ · x
122+ temp = A_list[i] * u
123+ rhs_list[i] .- = A_adj * temp
124+ end
125+ end
126+
127+ for i in eachindex (rhs_list)
128+ cache. linear_cache. b .= A_adj \ rhs_list[i]
129+ rhs_list[i] .= solve! (cache. linear_cache, alg, args... ; kwargs... ). u
130+ end
131+
132+ cache. linear_cache. b .= cache. primal_b_cache
133+ cache. linear_cache. u .= cache. primal_u_cache
134+
135+ return sol
136+ end
137+
66138 # Solve the primal problem
67139 cache. dual_u0_cache .= cache. linear_cache. u
68140 sol = solve! (cache. linear_cache, alg, args... ; kwargs... )
@@ -71,14 +143,13 @@ function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kw
71143 cache. primal_b_cache .= cache. linear_cache. b
72144 uu = sol. u
73145
74- # Solves Dual partials separately
146+ # Solves Dual partials separately
75147 ∂_A = cache. partials_A
76148 ∂_b = cache. partials_b
77149
78150 xp_linsolve_rhs! (uu, ∂_A, ∂_b, cache)
79151
80152 rhs_list = cache. rhs_list
81-
82153 cache. linear_cache. u .= cache. dual_u0_cache
83154 # We can reuse the linear cache, because the same factorization will work for the partials.
84155 for i in eachindex (rhs_list)
@@ -177,7 +248,6 @@ function linearsolve_dual_solution!(dual_u::AbstractArray{DT}, u::AbstractArray,
177248 partials) where {T, V, N, DT <: Dual{T, V, N} }
178249 # Direct in-place construction of dual numbers without temporary allocations
179250 n_partials = length (partials)
180-
181251 for i in eachindex (u, dual_u)
182252 # Extract partials for this element directly
183253 partial_vals = ntuple (Val (N)) do j
@@ -263,15 +333,26 @@ function __dual_init(
263333 partials_b_list = ! isnothing (∂_b) ? partials_to_list (∂_b) : nothing
264334
265335 # Determine size and type for rhs_list
336+ # For square systems, use b size. For overdetermined, use u size (solution size)
337+ rhs_template = length (non_partial_cache. u) == length (non_partial_cache. b) ?
338+ non_partial_cache. b : non_partial_cache. u
339+
266340 if ! isnothing (partials_A_list)
267341 n_partials = length (partials_A_list)
268- rhs_list = [similar (non_partial_cache . b ) for _ in 1 : n_partials]
342+ rhs_list = [similar (rhs_template ) for _ in 1 : n_partials]
269343 elseif ! isnothing (partials_b_list)
270344 n_partials = length (partials_b_list)
271- rhs_list = [similar (non_partial_cache . b ) for _ in 1 : n_partials]
345+ rhs_list = [similar (rhs_template ) for _ in 1 : n_partials]
272346 else
273347 rhs_list = nothing
274348 end
349+ # Use b for restructuring if sizes match (square system), otherwise use u (non-square)
350+ # This preserves ComponentArray structure from b when possible
351+ dual_u_init = if length (non_partial_cache. u) == length (b)
352+ ArrayInterface. restructure (b, zeros (dual_type, length (b)))
353+ else
354+ ArrayInterface. restructure (non_partial_cache. u, zeros (dual_type, length (non_partial_cache. u)))
355+ end
275356
276357 return DualLinearCache {dual_type} (
277358 non_partial_cache,
@@ -281,13 +362,13 @@ function __dual_init(
281362 partials_A_list,
282363 partials_b_list,
283364 rhs_list,
284- similar (new_b),
285- similar (new_b),
286- similar (new_b),
365+ similar (non_partial_cache . u), # Use u's size, not b's size
366+ similar (non_partial_cache . u), # primal_u_cache
367+ similar (new_b), # primal_b_cache
287368 true , # Cache is initially valid
288369 A,
289370 b,
290- ArrayInterface . restructure (b, zeros (dual_type, length (b)))
371+ dual_u_init
291372 )
292373end
293374
@@ -300,6 +381,8 @@ function SciMLBase.solve!(
300381 ForwardDiff. Dual}
301382 primal_sol = linearsolve_forwarddiff_solve! (
302383 cache:: DualLinearCache , getfield (cache, :linear_cache ). alg, args... ; kwargs... )
384+
385+ # Construct dual solution from primal solution and partials
303386 dual_sol = linearsolve_dual_solution (getfield (cache, :linear_cache ). u, getfield (cache, :rhs_list ), cache)
304387
305388 # For scalars, we still need to assign since cache.dual_u might not be pre-allocated
0 commit comments