@@ -3,8 +3,9 @@ module LinearSolveEnzymeExt
33using LinearSolve
44using LinearSolve. LinearAlgebra
55using EnzymeCore
6+ using EnzymeCore: EnzymeRules
67
7- function EnzymeCore . EnzymeRules. forward (config:: EnzymeCore. EnzymeRules.FwdConfigWidth{1} ,
8+ function EnzymeRules. forward (config:: EnzymeRules.FwdConfigWidth{1} ,
89 func:: Const{typeof(LinearSolve.init)} , :: Type{RT} , prob:: EnzymeCore.Annotation{LP} ,
910 alg:: Const ; kwargs... ) where {RT, LP <: LinearSolve.LinearProblem }
1011 @assert ! (prob isa Const)
@@ -19,26 +20,20 @@ function EnzymeCore.EnzymeRules.forward(config::EnzymeCore.EnzymeRules.FwdConfig
1920 dres = func. val (prob. dval, alg. val; kwargs... )
2021 dres. b .= res. b == dres. b ? zero (dres. b) : dres. b
2122 dres. A .= res. A == dres. A ? zero (dres. A) : dres. A
22- if RT <: DuplicatedNoNeed
23- return dres
24- elseif RT <: Duplicated
25- return Duplicated (res, dres)
26- end
27- error (" Unsupported return type $RT " )
2823
2924 if EnzymeRules. needs_primal (config) && EnzymeRules. needs_shadow (config)
30- Duplicated (res, dres)
25+ return Duplicated (res, dres)
3126 elseif EnzymeRules. needs_shadow (config)
32- dres
27+ return dres
3328 elseif EnzymeRules. needs_primal (config)
34- res
29+ return res
3530 else
36- nothing
31+ return nothing
3732 end
3833end
3934
40- function EnzymeCore . EnzymeRules. forward (
41- config:: EnzymeCore. EnzymeRules.FwdConfigWidth{1} , func:: Const{typeof(LinearSolve.solve!)} ,
35+ function EnzymeRules. forward (
36+ config:: EnzymeRules.FwdConfigWidth{1} , func:: Const{typeof(LinearSolve.solve!)} ,
4237 :: Type{RT} , linsolve:: EnzymeCore.Annotation{LP} ;
4338 kwargs... ) where {RT, LP <: LinearSolve.LinearCache }
4439 @assert ! (linsolve isa Const)
@@ -66,17 +61,17 @@ function EnzymeCore.EnzymeRules.forward(
6661 linsolve. val. b = b
6762
6863 if EnzymeRules. needs_primal (config) && EnzymeRules. needs_shadow (config)
69- Duplicated (res, dres)
64+ return Duplicated (res, dres)
7065 elseif EnzymeRules. needs_shadow (config)
71- dres
66+ return dres
7267 elseif EnzymeRules. needs_primal (config)
73- res
68+ return res
7469 else
75- nothing
70+ return nothing
7671 end
7772end
7873
79- function EnzymeCore . EnzymeRules. augmented_primal (
74+ function EnzymeRules. augmented_primal (
8075 config, func:: Const{typeof(LinearSolve.init)} ,
8176 :: Type{RT} , prob:: EnzymeCore.Annotation{LP} , alg:: Const ;
8277 kwargs... ) where {RT, LP <: LinearSolve.LinearProblem }
@@ -111,10 +106,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(
111106 (dval. b for dval in prob. dval)
112107 end
113108
114- return EnzymeCore . EnzymeRules. AugmentedReturn (res, dres, (d_A, d_b, prob_d_A, prob_d_b))
109+ return EnzymeRules. AugmentedReturn (res, dres, (d_A, d_b, prob_d_A, prob_d_b))
115110end
116111
117- function EnzymeCore . EnzymeRules. reverse (
112+ function EnzymeRules. reverse (
118113 config, func:: Const{typeof(LinearSolve.init)} , :: Type{RT} ,
119114 cache, prob:: EnzymeCore.Annotation{LP} , alg:: Const ;
120115 kwargs... ) where {RT, LP <: LinearSolve.LinearProblem }
148143# y=inv(A) B
149144# dA −= z y^T
150145# dB += z, where z = inv(A^T) dy
151- function EnzymeCore . EnzymeRules. augmented_primal (
146+ function EnzymeRules. augmented_primal (
152147 config, func:: Const{typeof(LinearSolve.solve!)} ,
153148 :: Type{RT} , linsolve:: EnzymeCore.Annotation{LP} ;
154149 kwargs... ) where {RT, LP <: LinearSolve.LinearCache }
@@ -201,10 +196,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(
201196 cachesolve = deepcopy (linsolve. val)
202197
203198 cache = (copy (res. u), resvals, cachesolve, dAs, dbs)
204- return EnzymeCore . EnzymeRules. AugmentedReturn (res, dres, cache)
199+ return EnzymeRules. AugmentedReturn (res, dres, cache)
205200end
206201
207- function EnzymeCore . EnzymeRules. reverse (config, func:: Const{typeof(LinearSolve.solve!)} ,
202+ function EnzymeRules. reverse (config, func:: Const{typeof(LinearSolve.solve!)} ,
208203 :: Type{RT} , cache, linsolve:: EnzymeCore.Annotation{LP} ;
209204 kwargs... ) where {RT, LP <: LinearSolve.LinearCache }
210205 y, dys, _linsolve, dAs, dbs = cache
0 commit comments