Skip to content

Commit 2d1c22f

Browse files
committed
make it check if it's square in general, add test for original equation
1 parent f8398c3 commit 2d1c22f

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,10 @@ LinearSolve.@concrete mutable struct DualLinearCache{DT}
6363
end
6464

6565
function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kwargs...)
66-
# Check if we're solving an overdetermined system (more rows than columns in A)
66+
# Check if A is square - if not, use the non-square system path
6767
A = cache.linear_cache.A
68-
is_overdetermined = size(A, 1) > size(A, 2)
6968

70-
if is_overdetermined
69+
if !issquare(A)
7170
# For overdetermined systems, differentiate the normal equations: A'Ax = A'b
7271
# Taking d/dθ of both sides:
7372
# dA'/dθ · Ax + A' · dA/dθ · x + A'A · dx/dθ = dA'/dθ · b + A' · db/dθ

test/forwarddiff_overloads.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,9 @@ backslash_large = A_large_dual \ b_large_dual
284284
# Test primal values match
285285
@test ForwardDiff.value.(sol_large.u) ForwardDiff.value.(backslash_large)
286286

287+
@test A_large_dual' * A_large_dual * sol_large.u A_large_dual' * b_large_dual
288+
@test A_large_dual' * A_large_dual * backslash_large A_large_dual' * b_large_dual
289+
287290
# Test partials match
288291
@test ForwardDiff.partials.(sol_large.u) ForwardDiff.partials.(backslash_large)
289292

0 commit comments

Comments
 (0)