Skip to content

Commit f8398c3

Browse files
committed
factor out A_adj
1 parent 94358bd commit f8398c3

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kw
8787
∂_A = cache.partials_A
8888
∂_b = cache.partials_b
8989
A = nodual_value(cache.dual_A)
90+
A_adj = A'
9091
b = cache.primal_b_cache
9192
residual = b - A * u # residual r = b - Ax
9293

@@ -110,7 +111,7 @@ function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kw
110111
for i in eachindex(rhs_list)
111112
if !isnothing(b_list)
112113
# A' · db/dθ
113-
rhs_list[i] .= A' * b_list[i]
114+
rhs_list[i] .= A_adj * b_list[i]
114115
else
115116
fill!(rhs_list[i], 0)
116117
end
@@ -120,12 +121,10 @@ function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kw
120121
rhs_list[i] .+= A_list[i]' * residual
121122
# Subtract A' · dA/dθ · x
122123
temp = A_list[i] * u
123-
rhs_list[i] .-= A' * temp
124+
rhs_list[i] .-= A_adj * temp
124125
end
125126
end
126127

127-
A_adj = A'
128-
129128
for i in eachindex(rhs_list)
130129
cache.linear_cache.b .= A_adj \ rhs_list[i]
131130
rhs_list[i] .= solve!(cache.linear_cache, alg, args...; kwargs...).u

0 commit comments

Comments
 (0)