Skip to content

Commit 6fcfbef

Browse files
Merge branch 'main' into dualreinit
2 parents 81b4b4c + 5fa4efd commit 6fcfbef

12 files changed

+125 −54 lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LinearSolve"
22
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
33
authors = ["SciML"]
4-
version = "3.48.0"
4+
version = "3.48.1"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -86,7 +86,7 @@ ArrayInterface = "7.17"
8686
BandedMatrices = "1.8"
8787
BlockDiagonals = "0.2"
8888
CUDA = "5.5"
89-
CUDSS = "0.4, 0.6.1"
89+
CUDSS = "0.6.3"
9090
CUSOLVERRF = "0.2.6"
9191
ChainRulesCore = "1.25"
9292
CliqueTrees = "1.11.0"

ext/LinearSolveCUDAExt.jl

Lines changed: 26 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,9 @@ using LinearSolve: LinearSolve, is_cusparse, defaultalg, cudss_loaded, DefaultLi
55
DefaultAlgorithmChoice, ALREADY_WARNED_CUDSS, LinearCache,
66
needs_concrete_A,
77
error_no_cudss_lu, init_cacheval, OperatorAssumptions,
8-
CudaOffloadFactorization, CudaOffloadLUFactorization,
9-
CudaOffloadQRFactorization,
8+
CudaOffloadFactorization, CudaOffloadLUFactorization, CudaOffloadQRFactorization,
109
CUDAOffload32MixedLUFactorization,
11-
SparspakFactorization, KLUFactorization, UMFPACKFactorization,
12-
LinearVerbosity
10+
SparspakFactorization, KLUFactorization, UMFPACKFactorization, LinearVerbosity
1311
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
1412
using SciMLBase: AbstractSciMLOperator
1513

@@ -19,23 +17,30 @@ function LinearSolve.is_cusparse(A::Union{
1917
CUDA.CUSPARSE.CuSparseMatrixCSR, CUDA.CUSPARSE.CuSparseMatrixCSC})
2018
true
2119
end
20+
LinearSolve.is_cusparse_csr(::CUDA.CUSPARSE.CuSparseMatrixCSR) = true
21+
LinearSolve.is_cusparse_csc(::CUDA.CUSPARSE.CuSparseMatrixCSC) = true
2222

2323
function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
2424
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
2525
if LinearSolve.cudss_loaded(A)
2626
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
2727
else
28-
error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library.")
28+
if !LinearSolve.ALREADY_WARNED_CUDSS[]
29+
@warn("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library. Falling back to Krylov")
30+
LinearSolve.ALREADY_WARNED_CUDSS[] = true
31+
end
32+
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
2933
end
3034
end
3135

32-
function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSC{Tv, Ti}, b,
33-
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
36+
function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSC, b,
37+
assump::OperatorAssumptions{Bool})
3438
if LinearSolve.cudss_loaded(A)
35-
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
39+
@warn("CUDSS.jl does not support CuSparseMatrixCSC for LU Factorizations, consider using CuSparseMatrixCSR instead. Falling back to Krylov", maxlog=1)
3640
else
37-
error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSC. Please load this library.")
41+
@warn("CuSparseMatrixCSC does not support LU Factorization falling back to Krylov. Consider using CUDSS.jl together with CuSparseMatrixCSR", maxlog=1)
3842
end
43+
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
3944
end
4045

4146
function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR)
@@ -45,13 +50,6 @@ function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR)
4550
nothing
4651
end
4752

48-
function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSC)
49-
if !LinearSolve.cudss_loaded(A)
50-
error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSC. Please load this library.")
51-
end
52-
nothing
53-
end
54-
5553
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFactorization;
5654
kwargs...)
5755
if cache.isfresh
@@ -66,15 +64,14 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFact
6664
SciMLBase.build_linear_solution(alg, y, nothing, cache)
6765
end
6866

69-
function LinearSolve.init_cacheval(
70-
alg::CudaOffloadLUFactorization, A::AbstractArray, b, u, Pl, Pr,
67+
function LinearSolve.init_cacheval(alg::CudaOffloadLUFactorization, A::AbstractArray, b, u, Pl, Pr,
7168
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
7269
assumptions::OperatorAssumptions)
7370
# Check if CUDA is functional before creating CUDA arrays
7471
if !CUDA.functional()
7572
return nothing
7673
end
77-
74+
7875
T = eltype(A)
7976
noUnitT = typeof(zero(T))
8077
luT = LinearAlgebra.lutype(noUnitT)
@@ -102,7 +99,7 @@ function LinearSolve.init_cacheval(alg::CudaOffloadQRFactorization, A, b, u, Pl,
10299
if !CUDA.functional()
103100
return nothing
104101
end
105-
102+
106103
qr(CUDA.CuArray(A))
107104
end
108105

@@ -119,42 +116,35 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactor
119116
SciMLBase.build_linear_solution(alg, y, nothing, cache)
120117
end
121118

122-
function LinearSolve.init_cacheval(
123-
alg::CudaOffloadFactorization, A::AbstractArray, b, u, Pl, Pr,
119+
function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A::AbstractArray, b, u, Pl, Pr,
124120
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
125121
assumptions::OperatorAssumptions)
126122
qr(CUDA.CuArray(A))
127123
end
128124

129125
function LinearSolve.init_cacheval(
130126
::SparspakFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
131-
Pl, Pr, maxiters::Int, abstol, reltol,
132-
verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
127+
Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
133128
nothing
134129
end
135130

136131
function LinearSolve.init_cacheval(
137132
::KLUFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
138-
Pl, Pr, maxiters::Int, abstol, reltol,
139-
verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
133+
Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
140134
nothing
141135
end
142136

143137
function LinearSolve.init_cacheval(
144138
::UMFPACKFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
145-
Pl, Pr, maxiters::Int, abstol, reltol,
146-
verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
139+
Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
147140
nothing
148141
end
149142

150143
# Mixed precision CUDA LU implementation
151-
function SciMLBase.solve!(
152-
cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
144+
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
153145
kwargs...)
154146
if cache.isfresh
155-
fact, A_gpu_f32,
156-
b_gpu_f32,
157-
u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
147+
fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
158148
# Compute 32-bit type on demand and convert
159149
T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
160150
A_f32 = T32.(cache.A)
@@ -163,14 +153,12 @@ function SciMLBase.solve!(
163153
cache.cacheval = (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32)
164154
cache.isfresh = false
165155
end
166-
fact, A_gpu_f32,
167-
b_gpu_f32,
168-
u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
169-
156+
fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
157+
170158
# Compute types on demand for conversions
171159
T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
172160
Torig = eltype(cache.u)
173-
161+
174162
# Convert b to Float32, solve, then convert back to original precision
175163
b_f32 = T32.(cache.b)
176164
copyto!(b_gpu_f32, b_f32)

ext/LinearSolveCUDSSExt.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,5 @@ using LinearSolve: LinearSolve, cudss_loaded
44
using CUDSS
55

66
LinearSolve.cudss_loaded(A::CUDSS.CUDA.CUSPARSE.CuSparseMatrixCSR) = true
7-
LinearSolve.cudss_loaded(A::CUDSS.CUDA.CUSPARSE.CuSparseMatrixCSC) = true
87

98
end

ext/LinearSolveCUSOLVERRFExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module LinearSolveCUSOLVERRFExt
22

3-
using LinearSolve: LinearSolve, @get_cacheval, pattern_changed, OperatorAssumptions
3+
using LinearSolve: LinearSolve, @get_cacheval, pattern_changed, OperatorAssumptions, LinearVerbosity
44
using CUSOLVERRF: CUSOLVERRF, RFLU, CUDA
55
using SparseArrays: SparseArrays, SparseMatrixCSC, nnz
66
using CUSOLVERRF.CUDA.CUSPARSE: CuSparseMatrixCSR

ext/LinearSolveForwardDiffExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,11 @@ function SciMLBase.init(prob::DualAbstractLinearProblem, alg::GenericLUFactoriza
200200
return __init(prob, alg, args...; kwargs...)
201201
end
202202

203+
# Opt out for SparspakFactorization
204+
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SparspakFactorization, args...; kwargs...)
205+
return __init(prob, alg, args...; kwargs...)
206+
end
207+
203208
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::DefaultLinearSolver, args...; kwargs...)
204209
if alg.alg === DefaultAlgorithmChoice.GenericLUFactorization
205210
return __init(prob, alg, args...; kwargs...)

ext/LinearSolveSparseArraysExt.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ function LinearSolve.init_cacheval(
129129
maxiters::Int, abstol, reltol,
130130
verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions) where {T <: BLASELTYPES}
131131
if LinearSolve.is_cusparse(A)
132-
ArrayInterface.lu_instance(A)
132+
LinearSolve.cudss_loaded(A) ? ArrayInterface.lu_instance(A) : nothing
133133
else
134134
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int64}(
135135
zero(Int64), zero(Int64), [Int64(1)], Int64[], T[]))
@@ -141,7 +141,7 @@ function LinearSolve.init_cacheval(
141141
maxiters::Int, abstol, reltol,
142142
verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions) where {T <: BLASELTYPES}
143143
if LinearSolve.is_cusparse(A)
144-
ArrayInterface.lu_instance(A)
144+
LinearSolve.cudss_loaded(A) ? ArrayInterface.lu_instance(A) : nothing
145145
else
146146
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int32}(
147147
zero(Int32), zero(Int32), [Int32(1)], Int32[], T[]))
@@ -344,7 +344,13 @@ function LinearSolve.init_cacheval(alg::NormalCholeskyFactorization,
344344
Symmetric{T, <:AbstractSparseArray{T}}}, b, u, Pl, Pr,
345345
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
346346
assumptions::OperatorAssumptions) where {T <: BLASELTYPES}
347-
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A))
347+
if LinearSolve.is_cusparse_csc(A)
348+
nothing
349+
elseif LinearSolve.is_cusparse_csr(A) && !LinearSolve.cudss_loaded(A)
350+
nothing
351+
else
352+
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A))
353+
end
348354
end
349355

350356
# Specialize QR for the non-square case

src/LinearSolve.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,8 @@ ALREADY_WARNED_CUDSS = Ref{Bool}(false)
478478
error_no_cudss_lu(A) = nothing
479479
cudss_loaded(A) = false
480480
is_cusparse(A) = false
481+
is_cusparse_csr(A) = false
482+
is_cusparse_csc(A) = false
481483

482484
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
483485
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization, ButterflyFactorization,

src/common.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,8 @@ function __init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
357357
u0_ = u0 !== nothing ? u0 : __init_u0_from_Ab(A, b)
358358

359359
# Guard against type mismatch for user-specified reltol/abstol
360-
reltol = real(eltype(prob.b))(reltol)
361-
abstol = real(eltype(prob.b))(abstol)
360+
reltol = real(eltype(prob.b))(SciMLBase.value(reltol))
361+
abstol = real(eltype(prob.b))(SciMLBase.value(abstol))
362362

363363
precs = if hasproperty(alg, :precs)
364364
isnothing(alg.precs) ? DEFAULT_PRECS : alg.precs

src/factorization.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,13 @@ end
395395
function init_cacheval(
396396
alg::CholeskyFactorization, A::AbstractArray{<:BLASELTYPES}, b, u, Pl, Pr,
397397
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
398-
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot)
398+
if LinearSolve.is_cusparse_csc(A)
399+
nothing
400+
elseif LinearSolve.is_cusparse_csr(A) && !LinearSolve.cudss_loaded(A)
401+
nothing
402+
else
403+
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot)
404+
end
399405
end
400406

401407
const PREALLOCATED_CHOLESKY = ArrayInterface.cholesky_instance(rand(1, 1), NoPivot())

test/forwarddiff_overloads.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using ForwardDiff
33
using Test
44
using SparseArrays
55
using ComponentArrays
6+
using Sparspak
67

78
function h(p)
89
(A = [p[1] p[2]+1 p[2]^3;
@@ -203,6 +204,12 @@ prob = LinearProblem(A, b)
203204

204205
@test init(prob) isa LinearSolve.LinearCache
205206

207+
# Test that SparspakFactorization doesn't create a DualLinearCache
208+
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
209+
210+
prob = LinearProblem(sparse(A), b)
211+
@test init(prob, SparspakFactorization()) isa LinearSolve.LinearCache
212+
206213
# Test ComponentArray with ForwardDiff (Issue SciML/DifferentialEquations.jl#1110)
207214
# This tests that ArrayInterface.restructure preserves ComponentArray structure
208215

0 commit comments

Comments (0)