@@ -5,11 +5,9 @@ using LinearSolve: LinearSolve, is_cusparse, defaultalg, cudss_loaded, DefaultLi
55 DefaultAlgorithmChoice, ALREADY_WARNED_CUDSS, LinearCache,
66 needs_concrete_A,
77 error_no_cudss_lu, init_cacheval, OperatorAssumptions,
8- CudaOffloadFactorization, CudaOffloadLUFactorization,
9- CudaOffloadQRFactorization,
8+ CudaOffloadFactorization, CudaOffloadLUFactorization, CudaOffloadQRFactorization,
109 CUDAOffload32MixedLUFactorization,
11- SparspakFactorization, KLUFactorization, UMFPACKFactorization,
12- LinearVerbosity
10+ SparspakFactorization, KLUFactorization, UMFPACKFactorization, LinearVerbosity
1311using LinearSolve. LinearAlgebra, LinearSolve. SciMLBase, LinearSolve. ArrayInterface
1412using SciMLBase: AbstractSciMLOperator
1513
# Trait query: report that both CUSPARSE storage formats (CSR and CSC)
# are cuSPARSE matrices, so generic code can branch on GPU-sparse inputs.
function LinearSolve.is_cusparse(A::Union{
        CUDA.CUSPARSE.CuSparseMatrixCSR, CUDA.CUSPARSE.CuSparseMatrixCSC})
    return true
end
# Storage-format-specific predicates: each returns `true` only for the
# matching CUSPARSE layout (CSR vs. CSC); other types use the fallback methods.
LinearSolve.is_cusparse_csr(::CUDA.CUSPARSE.CuSparseMatrixCSR) = true
LinearSolve.is_cusparse_csc(::CUDA.CUSPARSE.CuSparseMatrixCSC) = true
2222
# Default algorithm choice for a CSR sparse GPU matrix.
#
# CUDSS.jl supplies the sparse LU used for `CuSparseMatrixCSR`; when it is
# loaded we pick LU. Otherwise we warn once (guarded by the shared
# `ALREADY_WARNED_CUDSS` flag, which other code may also consult) and fall
# back to the Krylov GMRES solver, which needs no factorization support.
function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
        assump::OperatorAssumptions{Bool}) where {Tv, Ti}
    if LinearSolve.cudss_loaded(A)
        return LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
    end
    if !LinearSolve.ALREADY_WARNED_CUDSS[]
        @warn("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library. Falling back to Krylov")
        LinearSolve.ALREADY_WARNED_CUDSS[] = true
    end
    return LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
end
3135
# Default algorithm choice for a CSC sparse GPU matrix.
#
# There is no GPU LU path for CSC storage (CUDSS only handles CSR), so this
# method always falls back to Krylov GMRES. It emits one of two
# situation-specific warnings — each at most once via `maxlog = 1` — telling
# the user how to get a factorization-based solve instead.
function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSC, b,
        assump::OperatorAssumptions{Bool})
    if LinearSolve.cudss_loaded(A)
        # CUDSS is present but cannot consume CSC; suggest switching layout.
        @warn("CUDSS.jl does not support CuSparseMatrixCSC for LU Factorizations, consider using CuSparseMatrixCSR instead. Falling back to Krylov",
            maxlog = 1)
    else
        # Neither CUDSS nor CSC LU support is available.
        @warn("CuSparseMatrixCSC does not support LU Factorization falling back to Krylov. Consider using CUDSS.jl together with CuSparseMatrixCSR",
            maxlog = 1)
    end
    return LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
end
4045
4146function LinearSolve. error_no_cudss_lu (A:: CUDA.CUSPARSE.CuSparseMatrixCSR )
@@ -45,13 +50,6 @@ function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR)
4550 nothing
4651end
4752
48- function LinearSolve. error_no_cudss_lu (A:: CUDA.CUSPARSE.CuSparseMatrixCSC )
49- if ! LinearSolve. cudss_loaded (A)
50- error (" CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSC. Please load this library." )
51- end
52- nothing
53- end
54-
5553function SciMLBase. solve! (cache:: LinearSolve.LinearCache , alg:: CudaOffloadLUFactorization ;
5654 kwargs... )
5755 if cache. isfresh
@@ -66,15 +64,14 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFact
6664 SciMLBase. build_linear_solution (alg, y, nothing , cache)
6765end
6866
69- function LinearSolve. init_cacheval (
70- alg:: CudaOffloadLUFactorization , A:: AbstractArray , b, u, Pl, Pr,
67+ function LinearSolve. init_cacheval (alg:: CudaOffloadLUFactorization , A:: AbstractArray , b, u, Pl, Pr,
7168 maxiters:: Int , abstol, reltol, verbose:: Union{LinearVerbosity, Bool} ,
7269 assumptions:: OperatorAssumptions )
7370 # Check if CUDA is functional before creating CUDA arrays
7471 if ! CUDA. functional ()
7572 return nothing
7673 end
77-
74+
7875 T = eltype (A)
7976 noUnitT = typeof (zero (T))
8077 luT = LinearAlgebra. lutype (noUnitT)
@@ -102,7 +99,7 @@ function LinearSolve.init_cacheval(alg::CudaOffloadQRFactorization, A, b, u, Pl,
10299 if ! CUDA. functional ()
103100 return nothing
104101 end
105-
102+
106103 qr (CUDA. CuArray (A))
107104end
108105
@@ -119,42 +116,35 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactor
119116 SciMLBase. build_linear_solution (alg, y, nothing , cache)
120117end
121118
# Cache initialization for `CudaOffloadFactorization`: eagerly QR-factorize a
# GPU copy of `A`; the factorization object is the cache value reused by
# `solve!`.
#
# Fix: like the sibling `init_cacheval` methods for the LU/QR offload
# algorithms in this file, guard against a non-functional CUDA runtime before
# allocating device memory — otherwise constructing the `CuArray` errors on
# machines without a usable GPU. Returning `nothing` matches the siblings'
# "no cache available" convention.
function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A::AbstractArray, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
        assumptions::OperatorAssumptions)
    # Check if CUDA is functional before creating CUDA arrays
    if !CUDA.functional()
        return nothing
    end
    return qr(CUDA.CuArray(A))
end
128124
# `SparspakFactorization` is a CPU sparse factorization; for a CUSPARSE CSR
# input there is no device-side cache to build, so the cache value is `nothing`.
function LinearSolve.init_cacheval(
        ::SparspakFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
        Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
        assumptions::OperatorAssumptions)
    return nothing
end
135130
# `KLUFactorization` is a CPU sparse factorization; for a CUSPARSE CSR input
# there is no device-side cache to build, so the cache value is `nothing`.
function LinearSolve.init_cacheval(
        ::KLUFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
        Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
        assumptions::OperatorAssumptions)
    return nothing
end
142136
# `UMFPACKFactorization` is a CPU sparse factorization; for a CUSPARSE CSR
# input there is no device-side cache to build, so the cache value is `nothing`.
function LinearSolve.init_cacheval(
        ::UMFPACKFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
        Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
        assumptions::OperatorAssumptions)
    return nothing
end
149142
150143# Mixed precision CUDA LU implementation
151- function SciMLBase. solve! (
152- cache:: LinearSolve.LinearCache , alg:: CUDAOffload32MixedLUFactorization ;
144+ function SciMLBase. solve! (cache:: LinearSolve.LinearCache , alg:: CUDAOffload32MixedLUFactorization ;
153145 kwargs... )
154146 if cache. isfresh
155- fact, A_gpu_f32,
156- b_gpu_f32,
157- u_gpu_f32 = LinearSolve. @get_cacheval (cache, :CUDAOffload32MixedLUFactorization )
147+ fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve. @get_cacheval (cache, :CUDAOffload32MixedLUFactorization )
158148 # Compute 32-bit type on demand and convert
159149 T32 = eltype (cache. A) <: Complex ? ComplexF32 : Float32
160150 A_f32 = T32 .(cache. A)
@@ -163,14 +153,12 @@ function SciMLBase.solve!(
163153 cache. cacheval = (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32)
164154 cache. isfresh = false
165155 end
166- fact, A_gpu_f32,
167- b_gpu_f32,
168- u_gpu_f32 = LinearSolve. @get_cacheval (cache, :CUDAOffload32MixedLUFactorization )
169-
156+ fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve. @get_cacheval (cache, :CUDAOffload32MixedLUFactorization )
157+
170158 # Compute types on demand for conversions
171159 T32 = eltype (cache. A) <: Complex ? ComplexF32 : Float32
172160 Torig = eltype (cache. u)
173-
161+
174162 # Convert b to Float32, solve, then convert back to original precision
175163 b_f32 = T32 .(cache. b)
176164 copyto! (b_gpu_f32, b_f32)
0 commit comments