1
1
needs_concrete_A (alg:: DefaultLinearSolver ) = true
2
2
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
3
- T13, T14, T15, T16, T17, T18}
3
+ T13, T14, T15, T16, T17, T18, T19 }
4
4
LUFactorization:: T1
5
5
QRFactorization:: T2
6
6
DiagonalFactorization:: T3
@@ -19,6 +19,7 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
19
19
NormalCholeskyFactorization:: T16
20
20
AppleAccelerateLUFactorization:: T17
21
21
MKLLUFactorization:: T18
22
+ QRFactorizationPivoted:: T19
22
23
end
23
24
24
25
# Legacy fallback
@@ -168,8 +169,8 @@ function defaultalg(A, b, assump::OperatorAssumptions)
168
169
(A === nothing ? eltype (b) <: Union{Float32, Float64} :
169
170
eltype (A) <: Union{Float32, Float64} )
170
171
DefaultAlgorithmChoice. RFLUFactorization
171
- # elseif A === nothing || A isa Matrix
172
- # alg = FastLUFactorization()
172
+ # elseif A === nothing || A isa Matrix
173
+ # alg = FastLUFactorization()
173
174
elseif usemkl && (A === nothing ? eltype (b) <: Union{Float32, Float64} :
174
175
eltype (A) <: Union{Float32, Float64} )
175
176
DefaultAlgorithmChoice. MKLLUFactorization
@@ -199,9 +200,19 @@ function defaultalg(A, b, assump::OperatorAssumptions)
199
200
elseif assump. condition === OperatorCondition. WellConditioned
200
201
DefaultAlgorithmChoice. NormalCholeskyFactorization
201
202
elseif assump. condition === OperatorCondition. IllConditioned
202
- DefaultAlgorithmChoice. QRFactorization
203
+ if is_underdetermined (A)
204
+ # Underdetermined
205
+ DefaultAlgorithmChoice. QRFactorizationPivoted
206
+ else
207
+ DefaultAlgorithmChoice. QRFactorization
208
+ end
203
209
elseif assump. condition === OperatorCondition. VeryIllConditioned
204
- DefaultAlgorithmChoice. QRFactorization
210
+ if is_underdetermined (A)
211
+ # Underdetermined
212
+ DefaultAlgorithmChoice. QRFactorizationPivoted
213
+ else
214
+ DefaultAlgorithmChoice. QRFactorization
215
+ end
205
216
elseif assump. condition === OperatorCondition. SuperIllConditioned
206
217
DefaultAlgorithmChoice. SVDFactorization
207
218
else
@@ -247,6 +258,12 @@ function algchoice_to_alg(alg::Symbol)
247
258
NormalCholeskyFactorization ()
248
259
elseif alg === :AppleAccelerateLUFactorization
249
260
AppleAccelerateLUFactorization ()
261
+ elseif alg === :QRFactorizationPivoted
262
+ @static if VERSION ≥ v " 1.7beta"
263
+ QRFactorization (ColumnNorm ())
264
+ else
265
+ QRFactorization (Val (true ))
266
+ end
250
267
else
251
268
error (" Algorithm choice symbol $alg not allowed in the default" )
252
269
end
@@ -311,6 +328,12 @@ function defaultalg_symbol(::Type{T}) where {T}
311
328
end
312
329
defaultalg_symbol (:: Type{<:GenericFactorization{typeof(ldlt!)}} ) = :LDLtFactorization
313
330
331
+ @static if VERSION >= v " 1.7"
332
+ defaultalg_symbol (:: Type{<:QRFactorization{ColumnNorm}} ) = :QRFactorizationPivoted
333
+ else
334
+ defaultalg_symbol (:: Type{<:QRFactorization{Val{true}}} ) = :QRFactorizationPivoted
335
+ end
336
+
314
337
"""
315
338
if alg.alg === DefaultAlgorithmChoice.LUFactorization
316
339
SciMLBase.solve!(cache, LUFactorization(), args...; kwargs...))
339
362
end
340
363
ex = Expr (:if , ex. args... )
341
364
end
365
+
366
+ """
367
+ ```
368
+ elseif DefaultAlgorithmChoice.LUFactorization === cache.alg
369
+ (cache.cacheval.LUFactorization)' \\ dy
370
+ else
371
+ ...
372
+ end
373
+ ```
374
+ """
375
+ @generated function defaultalg_adjoint_eval (cache:: LinearCache , dy)
376
+ ex = :()
377
+ for alg in first .(EnumX. symbol_map (DefaultAlgorithmChoice. T))
378
+ newex = if alg in Symbol .((DefaultAlgorithmChoice. MKLLUFactorization,
379
+ DefaultAlgorithmChoice. AppleAccelerateLUFactorization,
380
+ DefaultAlgorithmChoice. RFLUFactorization))
381
+ quote
382
+ getproperty (cache. cacheval,$ (Meta. quot (alg)))[1 ]' \ dy
383
+ end
384
+ elseif alg in Symbol .((DefaultAlgorithmChoice. LUFactorization,
385
+ DefaultAlgorithmChoice. QRFactorization,
386
+ DefaultAlgorithmChoice. KLUFactorization,
387
+ DefaultAlgorithmChoice. UMFPACKFactorization,
388
+ DefaultAlgorithmChoice. LDLtFactorization,
389
+ DefaultAlgorithmChoice. SparspakFactorization,
390
+ DefaultAlgorithmChoice. BunchKaufmanFactorization,
391
+ DefaultAlgorithmChoice. CHOLMODFactorization,
392
+ DefaultAlgorithmChoice. SVDFactorization,
393
+ DefaultAlgorithmChoice. CholeskyFactorization,
394
+ DefaultAlgorithmChoice. NormalCholeskyFactorization,
395
+ DefaultAlgorithmChoice. QRFactorizationPivoted,
396
+ DefaultAlgorithmChoice. GenericLUFactorization))
397
+ quote
398
+ getproperty (cache. cacheval,$ (Meta. quot (alg)))' \ dy
399
+ end
400
+ elseif alg in Symbol .((DefaultAlgorithmChoice. KrylovJL_GMRES,))
401
+ quote
402
+ invprob = LinearSolve. LinearProblem (transpose (cache. A), dy)
403
+ solve (invprob, cache. alg;
404
+ abstol = cache. val. abstol,
405
+ reltol = cache. val. reltol,
406
+ verbose = cache. val. verbose)
407
+ end
408
+ else
409
+ quote
410
+ error (" Default linear solver with algorithm $(alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling" )
411
+ end
412
+ end
413
+
414
+ ex = if ex == :()
415
+ Expr (:elseif , :(getproperty (DefaultAlgorithmChoice, $ (Meta. quot (alg))) === cache. alg. alg), newex,
416
+ :(error (" Algorithm Choice not Allowed" )))
417
+ else
418
+ Expr (:elseif , :(getproperty (DefaultAlgorithmChoice, $ (Meta. quot (alg))) === cache. alg. alg), newex, ex)
419
+ end
420
+ end
421
+ ex = Expr (:if , ex. args... )
422
+ end
0 commit comments