diff --git a/src/symmetric.jl b/src/symmetric.jl index 2cc74763..caf6ed28 100644 --- a/src/symmetric.jl +++ b/src/symmetric.jl @@ -461,9 +461,34 @@ issymmetric(A::Hermitian{<:Real}) = true issymmetric(A::Hermitian{<:Complex}) = isreal(A) issymmetric(A::Symmetric) = true -# check if the symmetry is known from the type -_issymmetric(::Union{SymSymTri, Hermitian{<:Real}}) = true -_issymmetric(::Any) = false +""" + issymmetrictype(T::Type) + +Return whether every instance `x` of the type `T` satisfies `issymmetric(x) == tue`, +that is, the fact that the instance is symmetric is known from its type. + +!!! note + An instance `x::T` may still be symmetric when `issymmetrictype(T)` returns `false`. +""" +issymmetrictype(::Type) = false +issymmetrictype(::Type{<:Union{Symmetric,Hermitian{<:Real}}}) = true +issymmetrictype(::Type{<:Real}) = true +issymmetrictype(::Type{<:AbstractFloat}) = false +issymmetrictype(::Type{Complex{T}}) where {T} = issymmetrictype(T) + +""" + ishermitiantype(T::Type) + +Return whether every instance `x` of the type `T` satisfies `ishermitian(x) == tue`, +that is, the fact that the instance is hermitian is known from its type. + +!!! note + An instance `x::T` may still be hermitian when `ishermitiantype(T)` returns `false`. +""" +ishermitiantype(::Type) = false +ishermitiantype(::Type{<:Union{Symmetric{<:Real},Hermitian}}) = true +ishermitiantype(::Type{<:Real}) = true +ishermitiantype(::Type{<:AbstractFloat}) = false adjoint(A::Hermitian) = A transpose(A::Symmetric) = A diff --git a/src/tridiag.jl b/src/tridiag.jl index a0e3d821..8e12b8a3 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -111,14 +111,14 @@ function (::Type{SymTri})(A::AbstractMatrix) where {SymTri <: SymTridiagonal} checksquare(A) du = diag(A, 1) d = diag(A) - if !(_issymmetric(A) || _checksymmetric(d, du, diag(A, -1))) + if !(issymmetrictype(typeof(A)) || _checksymmetric(d, du, diag(A, -1))) throw(ArgumentError("matrix is not symmetric; cannot convert to SymTridiagonal")) end return SymTri(d, du) end _checksymmetric(d, du, dl) = all(((x, y),) -> x == transpose(y), zip(du, dl)) && all(issymmetric, d) -_checksymmetric(A::AbstractMatrix) = _issymmetric(A) || _checksymmetric(diagview(A), diagview(A, 1), diagview(A, -1)) +_checksymmetric(A::AbstractMatrix) = issymmetrictype(typeof(A)) || _checksymmetric(diagview(A), diagview(A, 1), diagview(A, -1)) SymTridiagonal{T,V}(S::SymTridiagonal{T,V}) where {T,V<:AbstractVector{T}} = S SymTridiagonal{T,V}(S::SymTridiagonal) where {T,V<:AbstractVector{T}} = diff --git a/test/symmetric.jl b/test/symmetric.jl index 707b392d..91f0d37d 100644 --- a/test/symmetric.jl +++ b/test/symmetric.jl @@ -1343,4 +1343,26 @@ end @test_throws msg LinearAlgebra.fillband!(Symmetric(A), 2, 0, 1) end +@testset "issymmetrictype/ishermitiantype" begin + fsym(x) = Val(LinearAlgebra.issymmetrictype(typeof(x))) + @test @inferred(fsym(Symmetric(ones(2,2)))) == Val(true) + @test @inferred(fsym(Symmetric(ones(ComplexF64,2,2)))) == Val(true) + @test @inferred(fsym(Hermitian(ones(2,2)))) == Val(true) + @test @inferred(fsym(Hermitian(ones(ComplexF64,2,2)))) == Val(false) + @test @inferred(fsym(1)) == Val(true) + @test @inferred(fsym(1.0)) == Val(false) + @test @inferred(fsym(complex(1))) == Val(true) + @test @inferred(fsym(complex(1.0))) == Val(false) + + fherm(x) = Val(LinearAlgebra.ishermitiantype(typeof(x))) + @test @inferred(fherm(Symmetric(ones(2,2)))) == Val(true) + @test @inferred(fherm(Symmetric(ones(ComplexF64,2,2)))) == Val(false) + @test @inferred(fherm(Hermitian(ones(2,2)))) == Val(true) + @test @inferred(fherm(Hermitian(ones(ComplexF64,2,2)))) == Val(true) + @test @inferred(fherm(1)) == Val(true) + @test @inferred(fherm(1.0)) == Val(false) + @test @inferred(fherm(complex(1))) == Val(false) + @test @inferred(fherm(complex(1.0))) == Val(false) +end + end # module TestSymmetric diff --git a/test/tridiag.jl b/test/tridiag.jl index 7a4d78b9..520192b3 100644 --- a/test/tridiag.jl +++ b/test/tridiag.jl @@ -1155,10 +1155,6 @@ end @testset "SymTridiagonal from Symmetric" begin S = Symmetric(reshape(1:9, 3, 3)) - @testset "helper functions" begin - @test LinearAlgebra._issymmetric(S) - @test !LinearAlgebra._issymmetric(Array(S)) - end ST = SymTridiagonal(S) @test ST == SymTridiagonal(diag(S), diag(S,1)) S = Symmetric(Tridiagonal(1:3, 1:4, 1:3))