Skip to content

Commit 94358bd

Browse files
committed
avoid creating normal matrix
1 parent 2a88978 commit 94358bd

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kw
7272
# Taking d/dθ of both sides:
7373
# dA'/dθ · Ax + A' · dA/dθ · x + A'A · dx/dθ = dA'/dθ · b + A' · db/dθ
7474
# Rearranging:
75-
# A'A · dx/dθ = A' · db/dθ + dA'/dθ · (b - Ax) - A' · dA/dθ · x
75+
# A'A · dx/dθ = A' · db/dθ + dA'/dθ · (b - Ax) - A' · dA/dθ · x
7676

7777
# Solve the primal problem first
7878
cache.dual_u0_cache .= cache.linear_cache.u
@@ -124,15 +124,11 @@ function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kw
124124
end
125125
end
126126

127-
# Solve A'A · dx/dθ = rhs for each partial
128-
# Create a cache for the normal equations and reuse the factorization
129-
AtA = A' * A
130-
normal_prob = LinearProblem(AtA, rhs_list[1])
131-
normal_cache = init(normal_prob, alg, args...; kwargs...)
127+
A_adj = A'
132128

133129
for i in eachindex(rhs_list)
134-
normal_cache.b .= rhs_list[i]
135-
rhs_list[i] .= solve!(normal_cache).u
130+
cache.linear_cache.b .= A_adj \ rhs_list[i]
131+
rhs_list[i] .= solve!(cache.linear_cache, alg, args...; kwargs...).u
136132
end
137133

138134
cache.linear_cache.b .= cache.primal_b_cache

0 commit comments

Comments
 (0)