Skip to content

Commit a69120e

Browse files
Merge pull request #845 from hersle/check_pattern_first
Fix and speed up sparsity pattern check
2 parents dac8f53 + ce2613c commit a69120e

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

ext/LinearSolveSparseArraysExt.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ function SciMLBase.solve!(
201201
cacheval = LinearSolve.@get_cacheval(cache, :UMFPACKFactorization)
202202
if alg.reuse_symbolic
203203
# Caches the symbolic factorization: https://github.com/JuliaLang/julia/pull/33738
204-
if alg.check_pattern && pattern_changed(cacheval, A)
204+
if length(cacheval.nzval) != length(A.nzval) || alg.check_pattern && pattern_changed(cacheval, A)
205205
fact = lu(
206206
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
207207
nonzeros(A)),
@@ -331,7 +331,7 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::KLUFactorization;
331331
if cache.isfresh
332332
cacheval = LinearSolve.@get_cacheval(cache, :KLUFactorization)
333333
if alg.reuse_symbolic
334-
if alg.check_pattern && pattern_changed(cacheval, A)
334+
if length(cacheval.nzval) != length(A.nzval) || alg.check_pattern && pattern_changed(cacheval, A)
335335
fact = KLU.klu(
336336
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
337337
nonzeros(A)),
@@ -455,9 +455,19 @@ function LinearSolve.pattern_changed(fact::Nothing, A::SparseArrays.SparseMatrix
455455
end
456456

457457
function LinearSolve.pattern_changed(fact, A::SparseArrays.SparseMatrixCSC)
458-
!(SparseArrays.decrement(SparseArrays.getcolptr(A)) ==
459-
fact.colptr && SparseArrays.decrement(SparseArrays.getrowval(A)) ==
460-
fact.rowval)
458+
colptr0 = fact.colptr # has 0-based indices
459+
colptr1 = SparseArrays.getcolptr(A) # has 1-based indices
460+
length(colptr0) == length(colptr1) || return true
461+
@inbounds for i in eachindex(colptr0)
462+
colptr0[i] + 1 == colptr1[i] || return true
463+
end
464+
rowval0 = fact.rowval
465+
rowval1 = SparseArrays.getrowval(A)
466+
length(rowval0) == length(rowval1) || return true
467+
@inbounds for i in eachindex(rowval0)
468+
rowval0[i] + 1 == rowval1[i] || return true
469+
end
470+
return false
461471
end
462472

463473
@static if Base.USE_GPL_LIBS

src/KLU/klu.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,9 @@ See the [`klu`](@ref) docs for more information.
134134
135135
You typically should not construct this directly, instead use [`klu`](@ref).
136136
"""
137-
mutable struct KLUFactorization{Tv <: KLUTypes, Ti <: KLUITypes} <:
137+
mutable struct KLUFactorization{Tv <: KLUTypes, Ti <: KLUITypes, Tklu <: Union{klu_l_common, klu_common}} <:
138138
AbstractKLUFactorization{Tv, Ti}
139-
common::Union{klu_l_common, klu_common}
139+
common::Tklu
140140
_symbolic::Ptr{Cvoid}
141141
_numeric::Ptr{Cvoid}
142142
n::Int
@@ -146,7 +146,7 @@ mutable struct KLUFactorization{Tv <: KLUTypes, Ti <: KLUITypes} <:
146146
function KLUFactorization(n, colptr, rowval, nzval)
147147
Ti = eltype(colptr)
148148
common = _common(Ti)
149-
obj = new{eltype(nzval), Ti}(common, C_NULL, C_NULL, n, colptr, rowval, nzval)
149+
obj = new{eltype(nzval), Ti, typeof(common)}(common, C_NULL, C_NULL, n, colptr, rowval, nzval)
150150
function f(klu)
151151
_free_symbolic(klu)
152152
_free_numeric(klu)

0 commit comments

Comments
 (0)