Skip to content

Commit a5f623e

Browse files
Merge pull request #842 from jClugstor/dual_numbers_fix
Add overloads for overdetermined system derivatives
2 parents 29c7aff + 2d1c22f commit a5f623e

File tree

2 files changed

+142
-9
lines changed

2 files changed

+142
-9
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 92 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,78 @@ LinearSolve.@concrete mutable struct DualLinearCache{DT}
6363
end
6464

6565
function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kwargs...)
66+
# Check if A is square - if not, use the non-square system path
67+
A = cache.linear_cache.A
68+
69+
if !issquare(A)
70+
# For overdetermined systems, differentiate the normal equations: A'Ax = A'b
71+
# Taking d/dθ of both sides:
72+
# dA'/dθ · Ax + A' · dA/dθ · x + A'A · dx/dθ = dA'/dθ · b + A' · db/dθ
73+
# Rearranging:
74+
# A'A · dx/dθ = A' · db/dθ + dA'/dθ · (b - Ax) - A' · dA/dθ · x
75+
76+
# Solve the primal problem first
77+
cache.dual_u0_cache .= cache.linear_cache.u
78+
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
79+
cache.primal_u_cache .= cache.linear_cache.u
80+
cache.primal_b_cache .= cache.linear_cache.b
81+
u = sol.u
82+
83+
# Get the partials and primal values
84+
# After solve!, cache.linear_cache.A may be modified by factorization,
85+
# so we extract primal A from the original dual_A stored in cache
86+
∂_A = cache.partials_A
87+
∂_b = cache.partials_b
88+
A = nodual_value(cache.dual_A)
89+
A_adj = A'
90+
b = cache.primal_b_cache
91+
residual = b - A * u # residual r = b - Ax
92+
93+
rhs_list = cache.rhs_list
94+
95+
# Update cached partials lists if cache is invalid
96+
if !cache.rhs_cache_valid
97+
if !isnothing(∂_A)
98+
update_partials_list!(∂_A, cache.partials_A_list)
99+
end
100+
if !isnothing(∂_b)
101+
update_partials_list!(∂_b, cache.partials_b_list)
102+
end
103+
cache.rhs_cache_valid = true
104+
end
105+
106+
A_list = cache.partials_A_list
107+
b_list = cache.partials_b_list
108+
109+
# Compute RHS: A' · db/dθ + dA'/dθ · (b - Ax) - A' · dA/dθ · x
110+
for i in eachindex(rhs_list)
111+
if !isnothing(b_list)
112+
# A' · db/dθ
113+
rhs_list[i] .= A_adj * b_list[i]
114+
else
115+
fill!(rhs_list[i], 0)
116+
end
117+
118+
if !isnothing(A_list)
119+
# Add dA'/dθ · (b - Ax) = (dA/dθ)' · residual
120+
rhs_list[i] .+= A_list[i]' * residual
121+
# Subtract A' · dA/dθ · x
122+
temp = A_list[i] * u
123+
rhs_list[i] .-= A_adj * temp
124+
end
125+
end
126+
127+
for i in eachindex(rhs_list)
128+
cache.linear_cache.b .= A_adj \ rhs_list[i]
129+
rhs_list[i] .= solve!(cache.linear_cache, alg, args...; kwargs...).u
130+
end
131+
132+
cache.linear_cache.b .= cache.primal_b_cache
133+
cache.linear_cache.u .= cache.primal_u_cache
134+
135+
return sol
136+
end
137+
66138
# Solve the primal problem
67139
cache.dual_u0_cache .= cache.linear_cache.u
68140
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
@@ -71,14 +143,13 @@ function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kw
71143
cache.primal_b_cache .= cache.linear_cache.b
72144
uu = sol.u
73145

74-
# Solves Dual partials separately
146+
# Solves Dual partials separately
75147
∂_A = cache.partials_A
76148
∂_b = cache.partials_b
77149

78150
xp_linsolve_rhs!(uu, ∂_A, ∂_b, cache)
79151

80152
rhs_list = cache.rhs_list
81-
82153
cache.linear_cache.u .= cache.dual_u0_cache
83154
# We can reuse the linear cache, because the same factorization will work for the partials.
84155
for i in eachindex(rhs_list)
@@ -177,7 +248,6 @@ function linearsolve_dual_solution!(dual_u::AbstractArray{DT}, u::AbstractArray,
177248
partials) where {T, V, N, DT <: Dual{T, V, N}}
178249
# Direct in-place construction of dual numbers without temporary allocations
179250
n_partials = length(partials)
180-
181251
for i in eachindex(u, dual_u)
182252
# Extract partials for this element directly
183253
partial_vals = ntuple(Val(N)) do j
@@ -263,15 +333,26 @@ function __dual_init(
263333
partials_b_list = !isnothing(∂_b) ? partials_to_list(∂_b) : nothing
264334

265335
# Determine size and type for rhs_list
336+
# For square systems, use b size. For overdetermined, use u size (solution size)
337+
rhs_template = length(non_partial_cache.u) == length(non_partial_cache.b) ?
338+
non_partial_cache.b : non_partial_cache.u
339+
266340
if !isnothing(partials_A_list)
267341
n_partials = length(partials_A_list)
268-
rhs_list = [similar(non_partial_cache.b) for _ in 1:n_partials]
342+
rhs_list = [similar(rhs_template) for _ in 1:n_partials]
269343
elseif !isnothing(partials_b_list)
270344
n_partials = length(partials_b_list)
271-
rhs_list = [similar(non_partial_cache.b) for _ in 1:n_partials]
345+
rhs_list = [similar(rhs_template) for _ in 1:n_partials]
272346
else
273347
rhs_list = nothing
274348
end
349+
# Use b for restructuring if sizes match (square system), otherwise use u (non-square)
350+
# This preserves ComponentArray structure from b when possible
351+
dual_u_init = if length(non_partial_cache.u) == length(b)
352+
ArrayInterface.restructure(b, zeros(dual_type, length(b)))
353+
else
354+
ArrayInterface.restructure(non_partial_cache.u, zeros(dual_type, length(non_partial_cache.u)))
355+
end
275356

276357
return DualLinearCache{dual_type}(
277358
non_partial_cache,
@@ -281,13 +362,13 @@ function __dual_init(
281362
partials_A_list,
282363
partials_b_list,
283364
rhs_list,
284-
similar(new_b),
285-
similar(new_b),
286-
similar(new_b),
365+
similar(non_partial_cache.u), # Use u's size, not b's size
366+
similar(non_partial_cache.u), # primal_u_cache
367+
similar(new_b), # primal_b_cache
287368
true, # Cache is initially valid
288369
A,
289370
b,
290-
ArrayInterface.restructure(b, zeros(dual_type, length(b)))
371+
dual_u_init
291372
)
292373
end
293374

@@ -300,6 +381,8 @@ function SciMLBase.solve!(
300381
ForwardDiff.Dual}
301382
primal_sol = linearsolve_forwarddiff_solve!(
302383
cache::DualLinearCache, getfield(cache, :linear_cache).alg, args...; kwargs...)
384+
385+
# Construct dual solution from primal solution and partials
303386
dual_sol = linearsolve_dual_solution(getfield(cache, :linear_cache).u, getfield(cache, :rhs_list), cache)
304387

305388
# For scalars, we still need to assign since cache.dual_u might not be pre-allocated

test/forwarddiff_overloads.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,3 +241,53 @@ grad = ForwardDiff.gradient(component_linsolve, p_test)
241241
@test length(grad) == 2
242242
@test !any(isnan, grad)
243243
@test !any(isinf, grad)
244+
245+
# Test overdetermined (non-square) system: 2×1 matrix with dual numbers
246+
# This tests that cache sizes are correctly allocated when solution size != RHS size
247+
A_overdet = reshape([ForwardDiff.Dual(2.0, 1.0), ForwardDiff.Dual(3.0, 1.0)], 2, 1) # 2×1 matrix
248+
b_overdet = [ForwardDiff.Dual(5.0, 1.0), ForwardDiff.Dual(8.0, 9.0)]
249+
250+
prob_overdet = LinearProblem(A_overdet, b_overdet)
251+
sol_overdet = solve(prob_overdet)
252+
backslash_overdet = A_overdet \ b_overdet
253+
254+
# Test that solution has correct dimensions (length 1, not length 2)
255+
@test length(sol_overdet.u) == 1
256+
257+
# Primal values should match
258+
@test ForwardDiff.value.(sol_overdet.u) ≈ ForwardDiff.value.(backslash_overdet)
259+
260+
# Dual values should match
261+
@test ForwardDiff.partials.(sol_overdet.u) ≈ ForwardDiff.partials.(backslash_overdet)
262+
263+
# Test with cache - should give identical results
264+
cache_overdet = init(prob_overdet)
265+
sol_cache_overdet = solve!(cache_overdet)
266+
@test sol_cache_overdet.u ≈ sol_overdet.u
267+
268+
# Dual values should match
269+
@test ForwardDiff.partials.(sol_overdet.u) ≈ ForwardDiff.partials.(backslash_overdet)
270+
271+
# Test larger overdetermined system with dual numbers
272+
m, n = 10, 3
273+
A_large = rand(m, n)
274+
p = [2.0, 3.0]
275+
A_large_dual = [ForwardDiff.Dual(A_large[i, j], i == 1 ? 1.0 : 0.0, j == 1 ? 1.0 : 0.0)
276+
for i in 1:m, j in 1:n]
277+
b_large_dual = [ForwardDiff.Dual(rand(), i == 1 ? 1.0 : 0.0, i == 2 ? 1.0 : 0.0)
278+
for i in 1:m]
279+
280+
prob_large = LinearProblem(A_large_dual, b_large_dual)
281+
sol_large = solve(prob_large)
282+
backslash_large = A_large_dual \ b_large_dual
283+
284+
# Test primal values match
285+
@test ForwardDiff.value.(sol_large.u) ≈ ForwardDiff.value.(backslash_large)
286+
287+
@test A_large_dual' * A_large_dual * sol_large.u ≈ A_large_dual' * b_large_dual
288+
@test A_large_dual' * A_large_dual * backslash_large ≈ A_large_dual' * b_large_dual
289+
290+
# Test partials match
291+
@test ForwardDiff.partials.(sol_large.u) ≈ ForwardDiff.partials.(backslash_large)
292+
293+

0 commit comments

Comments
 (0)