Skip to content

Commit 72bc014

Browse files
committed
Support StaticArrays Properly
1 parent e712202 commit 72bc014

File tree

4 files changed

+36
-7
lines changed

4 files changed

+36
-7
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LinearSolve"
22
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
33
authors = ["SciML"]
4-
version = "2.20.0"
4+
version = "2.20.1"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -26,6 +26,7 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
2626
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2727
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2828
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
29+
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2930
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3031

3132
[weakdeps]
@@ -48,14 +49,14 @@ LinearSolveBandedMatricesExt = "BandedMatrices"
4849
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
4950
LinearSolveCUDAExt = "CUDA"
5051
LinearSolveEnzymeExt = ["Enzyme", "EnzymeCore"]
52+
LinearSolveFastAlmostBandedMatricesExt = ["FastAlmostBandedMatrices"]
5153
LinearSolveHYPREExt = "HYPRE"
5254
LinearSolveIterativeSolversExt = "IterativeSolvers"
5355
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
5456
LinearSolveKrylovKitExt = "KrylovKit"
5557
LinearSolveMetalExt = "Metal"
5658
LinearSolvePardisoExt = "Pardiso"
5759
LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
58-
LinearSolveFastAlmostBandedMatricesExt = ["FastAlmostBandedMatrices"]
5960

6061
[compat]
6162
Aqua = "0.8"
@@ -102,6 +103,7 @@ Setfield = "1"
102103
SparseArrays = "1.9"
103104
Sparspak = "0.3.6"
104105
Test = "1"
106+
StaticArraysCore = "1"
105107
UnPack = "1"
106108
julia = "1.9"
107109

src/LinearSolve.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ PrecompileTools.@recompile_invalidations begin
2626
using Requires
2727
import InteractiveUtils
2828

29+
import StaticArraysCore: StaticArray, SVector, MVector, SMatrix, MMatrix
30+
2931
using LinearAlgebra: BlasInt, LU
3032
using LinearAlgebra.LAPACK: require_one_based_indexing,
3133
chkfinite, chkstride1,

src/default.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,10 @@ function defaultalg(A, b, assump::OperatorAssumptions)
175175
DefaultAlgorithmChoice.LUFactorization
176176
end
177177

178+
# For static arrays GMRES allocates a lot. Use factorization
179+
elseif A isa StaticArray
180+
DefaultAlgorithmChoice.LUFactorization
181+
178182
# This catches the cases where a factorization overload could exist
179183
# For example, BlockBandedMatrix
180184
elseif A !== nothing && ArrayInterface.isstructured(A)
@@ -186,6 +190,9 @@ function defaultalg(A, b, assump::OperatorAssumptions)
186190
end
187191
elseif assump.condition === OperatorCondition.WellConditioned
188192
DefaultAlgorithmChoice.NormalCholeskyFactorization
193+
elseif A isa StaticArray
194+
# Static Array doesn't have QR() \ b defined
195+
return DefaultAlgorithmChoice.SVDFactorization
189196
elseif assump.condition === OperatorCondition.IllConditioned
190197
if is_underdetermined(A)
191198
# Underdetermined

src/factorization.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ end
1010

1111
_ldiv!(x, A, b) = ldiv!(x, A, b)
1212

13+
_ldiv!(x::MVector, A, b::SVector) = (x .= A \ b)
14+
_ldiv!(::SVector, A, b::SVector) = (A \ b)
15+
1316
function _ldiv!(x::Vector, A::Factorization, b::Vector)
1417
# workaround https://github.com/JuliaLang/julia/issues/43507
1518
# Fallback if working with non-square matrices
@@ -88,6 +91,8 @@ function do_factorization(alg::LUFactorization, A, b, u)
8891
if A isa AbstractSparseMatrixCSC
8992
return lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
9093
check = false)
94+
elseif !ArrayInterface.can_setindex(typeof(A))
95+
fact = lu(A, alg.pivot, check = false)
9196
else
9297
fact = lu!(A, alg.pivot, check = false)
9398
end
@@ -172,10 +177,14 @@ end
172177

173178
function do_factorization(alg::QRFactorization, A, b, u)
174179
A = convert(AbstractMatrix, A)
175-
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)
176-
fact = qr!(A, alg.pivot)
180+
if ArrayInterface.can_setindex(typeof(A))
181+
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)
182+
fact = qr!(A, alg.pivot)
183+
else
184+
fact = qr(A) # CUDA.jl does not allow other args!
185+
end
177186
else
178-
fact = qr(A) # CUDA.jl does not allow other args!
187+
fact = qr(A, alg.pivot)
179188
end
180189
return fact
181190
end
@@ -372,11 +381,15 @@ SVDFactorization() = SVDFactorization(false, LinearAlgebra.DivideAndConquer())
372381

373382
function do_factorization(alg::SVDFactorization, A, b, u)
374383
A = convert(AbstractMatrix, A)
375-
fact = svd!(A; full = alg.full, alg = alg.alg)
384+
if ArrayInterface.can_setindex(typeof(A))
385+
fact = svd!(A; alg.full, alg.alg)
386+
else
387+
fact = svd(A; alg.full)
388+
end
376389
return fact
377390
end
378391

379-
function init_cacheval(alg::SVDFactorization, A::Matrix, b, u, Pl, Pr,
392+
function init_cacheval(alg::SVDFactorization, A::Union{Matrix, SMatrix}, b, u, Pl, Pr,
380393
maxiters::Int, abstol, reltol, verbose::Bool,
381394
assumptions::OperatorAssumptions)
382395
ArrayInterface.svd_instance(convert(AbstractMatrix, A))
@@ -1354,6 +1367,11 @@ function init_cacheval(::SparspakFactorization, A, b, u, Pl, Pr, maxiters::Int,
13541367
end
13551368
end
13561369

1370+
function init_cacheval(::SparspakFactorization, ::StaticArray, b, u, Pl, Pr,
1371+
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
1372+
nothing
1373+
end
1374+
13571375
function SciMLBase.solve!(cache::LinearCache, alg::SparspakFactorization; kwargs...)
13581376
A = cache.A
13591377
if cache.isfresh

0 commit comments

Comments
 (0)