diff --git a/LICENSE b/LICENSE index 28e07032..b9a9755c 100644 --- a/LICENSE +++ b/LICENSE @@ -22,3 +22,35 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +Portions of this code are derived from SciPy and are licensed under +the SciPy License: + +> Copyright (c) 2001, 2002 Enthought, Inc. +> All rights reserved. + +> Copyright (c) 2003-2012 SciPy Developers. +> All rights reserved. + +> Redistribution and use in source and binary forms, with or without +> modification, are permitted provided that the following conditions are met: + +> a. Redistributions of source code must retain the above copyright notice, +> this list of conditions and the following disclaimer. +> b. Redistributions in binary form must reproduce the above copyright +> notice, this list of conditions and the following disclaimer in the +> documentation and/or other materials provided with the distribution. +> c. Neither the name of Enthought nor the names of the SciPy Developers +> may be used to endorse or promote products derived from this software +> without specific prior written permission. +> +> THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +> AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +> IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +> ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS +> BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +> OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +> SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +> INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +> CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +> ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +> THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/README.md b/README.md index dea48cb9..709e35e0 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # SpecialFunctions.jl Special mathematical functions in Julia, including Bessel, Hankel, Airy, error, Dawson, sine and cosine integrals, -eta, zeta, digamma, inverse digamma, trigamma, and polygamma functions. +eta, zeta, digamma, inverse digamma, trigamma, polygamma, and Lambert W functions. Most of these functions were formerly part of Base. Note: On Julia 0.7, this package downloads and/or builds diff --git a/docs/src/index.md b/docs/src/index.md index 4dead14b..d2f29f87 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -40,6 +40,7 @@ libraries. | [`besselix(nu,z)`](@ref SpecialFunctions.besselix) | scaled modified Bessel function of the first kind of order `nu` at `z` | | [`besselk(nu,z)`](@ref SpecialFunctions.besselk) | modified [Bessel function](https://en.wikipedia.org/wiki/Bessel_function) of the second kind of order `nu` at `z` | | [`besselkx(nu,z)`](@ref SpecialFunctions.besselkx) | scaled modified Bessel function of the second kind of order `nu` at `z` | +| [`lambertw(z,k)`](@ref SpecialFunctions.lambertw) | `k`th branch of the Lambert W function at `z` | ## Installation diff --git a/docs/src/special.md b/docs/src/special.md index f8bb48b1..c8e06e8b 100644 --- a/docs/src/special.md +++ b/docs/src/special.md @@ -46,4 +46,7 @@ SpecialFunctions.besselk SpecialFunctions.besselkx SpecialFunctions.eta SpecialFunctions.zeta +SpecialFunctions.lambertw +SpecialFunctions.lambertwbp +SpecialFunctions.omega ``` diff --git a/src/SpecialFunctions.jl b/src/SpecialFunctions.jl index 5994a194..7ede35a4 100644 --- a/src/SpecialFunctions.jl +++ b/src/SpecialFunctions.jl @@ -71,10 +71,20 @@ end export sinint, cosint +export lambertw, lambertwbp + +const omega_const_bf_ = Ref{BigFloat}() +function __init__() + # allocate storage for this BigFloat constant each time this module is loaded + omega_const_bf_[] = + 
parse(BigFloat,"0.5671432904097838729999686622103555497538157871865125081351310792230457930866845666932194") +end + include("bessel.jl") include("erf.jl") include("sincosint.jl") include("gamma.jl") +include("lambertw.jl") include("deprecated.jl") end # module diff --git a/src/lambertw.jl b/src/lambertw.jl new file mode 100644 index 00000000..5fdd9e74 --- /dev/null +++ b/src/lambertw.jl @@ -0,0 +1,358 @@ +import Base: convert + +using Compat +import Compat.MathConstants # For clarity, we use MathConstants.e for Euler's number + +#### Lambert W function #### + +# Use Halley's root-finding method to solve x*exp(x) = z for x = lambertw(z), starting from +# the initial point x. Warns and returns the last iterate if not converged after maxits steps. +function _lambertw(z::T, x::T, maxits) where T <: Number + two_t = convert(T,2) + lastx = x + lastdiff = zero(abs(x)) # same (real) type as xdiff below, also for complex T + converged::Bool = false + for i in 1:maxits + ex = exp(x) + xexz = x * ex - z + x1 = x + 1 + x -= xexz / (ex * x1 - (x + two_t) * xexz / (two_t * x1 ) ) + xdiff = abs(lastx - x) + if xdiff <= 3*eps(abs(lastx)) || lastdiff == xdiff # second condition catches two-value cycle + converged = true + break + end + lastx = x + lastdiff = xdiff + end + converged || Compat.@warn "lambertw with z=$z did not converge in $maxits iterations." + return x +end + +### Real z ### + +# Real x, k = 0 +# This appears to be inferrable with T=Float64 and T=BigFloat, including if x=Inf. +# The fancy initial condition selection does not seem to help speed, but we leave it for now. 
+function lambertwk0(x::T, maxits)::T where T<:AbstractFloat + isnan(x) && return(NaN) + x == Inf && return Inf # appears to return convert(BigFloat,Inf) for x == BigFloat(Inf) + one_t = one(T) + oneoe = -one_t/convert(T,MathConstants.e) # The branch point + x == oneoe && return -one_t + oneoe <= x || throw(DomainError(x)) + itwo_t = 1/convert(T,2) + if x > one_t + lx = log(x) + llx = log(lx) + x1 = lx - llx - log(one_t - llx/lx) * itwo_t + else + x1 = (567//1000) * x + end + return _lambertw(x, x1, maxits) +end + +# Real x, k = -1 +function lambertwkm1(x::T, maxits) where T<:Real + oneoe = -one(T)/convert(T,MathConstants.e) + x == oneoe && return -one(T) # W approaches -1 as x -> -1/e from above + oneoe <= x || throw(DomainError(x)) # branch domain excludes x < -1/e + x == zero(T) && return -convert(T,Inf) # W decreases w/o bound as x -> 0 from below + x < zero(T) || throw(DomainError(x)) + return _lambertw(x, log(-x), maxits) +end + +""" + lambertw(z::Complex{T}, k::V=0, maxits=1000) where {T<:Real, V<:Integer} + lambertw(z::T, k::V=0, maxits=1000) where {T<:Real, V<:Integer} + +Compute the `k`th branch of the Lambert W function of `z`. If `z` is real, `k` must be +either `0` or `-1`. For `Real` `z`, the domain of the branch `k = -1` is `[-1/e,0]` and the +domain of the branch `k = 0` is `[-1/e,Inf]`. For `Complex` `z`, and all `k`, the domain is +the complex plane. When using root finding to compute `W`, a value for `W` is returned +with a warning if it has not converged after `maxits` iterations. 
+ +```jldoctest +julia> lambertw(-1/e,-1) +-1.0 + +julia> lambertw(-1/e,0) +-1.0 + +julia> lambertw(0,0) +0.0 + +julia> lambertw(0,-1) +-Inf + +julia> lambertw(Complex(-10.0,3.0), 4) +-0.9274337508660128 + 26.37693445371142im +``` + +""" +lambertw(z, k::Integer=0, maxits::Integer=1000) = lambertw_(z, k, maxits) + +function lambertw_(x::Real, k, maxits) + k == 0 && return lambertwk0(x, maxits) + k == -1 && return lambertwkm1(x, maxits) + throw(DomainError(k, "lambertw: real x must have branch k == 0 or k == -1")) +end + +function lambertw_(x::Union{Integer,Rational}, k, maxits) + if k == 0 + x == 0 && return float(zero(x)) + x == 1 && return convert(typeof(float(x)), omega) # must be a more efficient way + end + return lambertw_(float(x), k, maxits) +end + +### Complex z ### + +# choose initial value inside correct branch for root finding +function lambertw_(z::Complex{T}, k, maxits) where T<:Real + one_t = one(T) + local w::Complex{T} + pointseven = 7//10 + if abs(z) <= one_t/convert(T,MathConstants.e) + if z == 0 + k == 0 && return z + return complex(-convert(T,Inf),zero(T)) + end + if k == 0 + w = z + elseif k == -1 && imag(z) == 0 && real(z) < 0 + w = complex(log(-real(z)),1//10^7) # need offset for z ≈ -1/e. + else + w = log(z) + k != 0 ? w += complex(0,k * 2 * pi) : nothing + end + elseif k == 0 && imag(z) <= pointseven && abs(z) <= pointseven + w = abs(z+ 1//2) < 1//10 ? imag(z) > 0 ? complex(pointseven,pointseven) : complex(pointseven,-pointseven) : z + else + if real(z) == convert(T,Inf) + k == 0 && return z + return z + complex(0,2*k*pi) + end + real(z) == -convert(T,Inf) && return -z + complex(0,(2*k+1)*pi) + w = log(z) + k != 0 ? w += complex(0, 2*k*pi) : nothing + end + return _lambertw(z, w, maxits) +end + +lambertw_(z::Complex{T}, k, maxits) where T<:Integer = lambertw_(float(z), k, maxits) +lambertw_(n::Irrational, k, maxits) = lambertw_(float(n), k, maxits) + +# lambertw(e + 0im,k) is ok for all k +# Maybe this should return a float. 
But, this should cause no type instability in any case +function lambertw_(::typeof(MathConstants.e), k, maxits) + k == 0 && return 1 + throw(DomainError(k, "lambertw: real x = e is only in the domain of branch k == 0")) +end + +### omega constant ### + +const omega_const_ = 0.567143290409783872999968662210355 +# The BigFloat `omega_const_bf_` is set via a literal in the function __init__ to prevent a segfault + +# maybe compute higher precision. converges very quickly +function omega_const(::Type{BigFloat}) + precision(BigFloat) <= 256 && return omega_const_bf_[] + myeps = eps(BigFloat) + oc = omega_const_bf_[] + for i in 1:100 + nextoc = (1 + oc) / (1 + exp(oc)) + abs(oc - nextoc) <= myeps && break + oc = nextoc + end + return oc +end + +""" + omega + ω + +The constant defined by `ω exp(ω) = 1`. + +```jldoctest +julia> ω +ω = 0.5671432904097... + +julia> omega +ω = 0.5671432904097... + +julia> ω * exp(ω) +1.0 + +julia> big(omega) +5.67143290409783872999968662210355549753815787186512508135131079223045793086683e-01 +``` +""" +const ω = Irrational{:ω}() +@doc (@doc ω) const omega = ω + +# The following two lines may be removed when support for v0.6 is dropped +Base.convert(::Type{AbstractFloat}, o::Irrational{:ω}) = Float64(o) +Base.convert(::Type{Float16}, o::Irrational{:ω}) = Float16(o) +Base.convert(::Type{T}, o::Irrational{:ω}) where T <:Number = T(o) + +Base.Float64(::Irrational{:ω}) = omega_const_ # FIXME: This is very slow. Why ? +Base.Float32(::Irrational{:ω}) = Float32(omega_const_) +Base.Float16(::Irrational{:ω}) = Float16(omega_const_) +Base.BigFloat(o::Irrational{:ω}) = omega_const(BigFloat) + +### Expansion about branch point x = -1/e ### + +# Refer to the paper "On the Lambert W function". In (4.22) +# coefficients μ₀ through μ₃ are given explicitly. Recursion relations +# (4.23) and (4.24) for all μ are also given. This code implements the +# recursion relations. + +# (4.23) and (4.24) give zero based coefficients. 
+cset(a,i,v) = a[i+1] = v +cget(a,i) = a[i+1] + +# (4.24): compute α_k from the μ values already stored in m, and store it in a +function compa(k,m,a) + sum0 = zero(eltype(m)) + for j in 2:k-1 + sum0 += cget(m,j) * cget(m,k+1-j) + end + cset(a,k,sum0) + return sum0 +end + +# (4.23): compute μ_k from earlier μ and α values, and store it in m +function compm(k,m,a) + kt = convert(eltype(m),k) + mk = (kt-1)/(kt+1) *(cget(m,k-2)/2 + cget(a,k-2)/4) - + cget(a,k)/2 - cget(m,k-1)/(kt+1) + cset(m,k,mk) + return mk +end + +# We plug the known value μ₂ == -1//3 for (4.22) into (4.23) and +# solve for α₂. We get α₂ = 0. +# Compute array of coefficients μ in (4.22). +# m[1] is μ₀ +function lamwcoeff(T::DataType, n::Int) + # a = @compat Array{T}(undef,n) + # m = @compat Array{T}(undef,n) + a = zeros(T,n) # Zero-fill is not needed, but it avoids the version-dependent `undef` array syntax above. + m = zeros(T,n) + cset(a,0,2) # α₀ literal in paper + cset(a,1,-1) # α₁ literal in paper + cset(a,2,0) # α₂ get this by solving (4.23) for alpha_2 with values printed in paper + cset(m,0,-1) # μ₀ literal in paper + cset(m,1,1) # μ₁ literal in paper + cset(m,2,-1//3) # μ₂ literal in paper, but only in (4.22) + for i in 3:n-1 # coeffs are zero indexed + compa(i,m,a) + compm(i,m,a) + end + return m +end + +const LAMWMU_FLOAT64 = lamwcoeff(Float64,500) + +# Base.Math.@horner requires literal coefficients +# But, we have an array `p` of computed coefficients +function horner(x, p::AbstractArray, n) + n += 1 + ex = p[n] + for i = n-1:-1:2 + ex = :(muladd(t, $ex, $(p[i]))) + end + ex = :( t * $ex) + return Expr(:block, :(t = $x), ex) +end + +function mkwser(name, n) + iex = horner(:x,LAMWMU_FLOAT64,n) + return :(function ($name)(x) $iex end) +end + +eval(mkwser(:wser3, 3)) +eval(mkwser(:wser5, 5)) +eval(mkwser(:wser7, 7)) +eval(mkwser(:wser12, 12)) +eval(mkwser(:wser19, 19)) +eval(mkwser(:wser26, 26)) +eval(mkwser(:wser32, 32)) +eval(mkwser(:wser50, 50)) +eval(mkwser(:wser100, 100)) +eval(mkwser(:wser290, 290)) + +# Converges to Float64 precision +# We could get finer tuning by separating k=0,-1 branches. 
+function wser(p,x) + x < 4e-11 && return wser3(p) + x < 1e-5 && return wser7(p) + x < 1e-3 && return wser12(p) + x < 1e-2 && return wser19(p) + x < 3e-2 && return wser26(p) + x < 5e-2 && return wser32(p) + x < 1e-1 && return wser50(p) + x < 1.9e-1 && return wser100(p) + x > 1/MathConstants.e && throw(DomainError(x)) # radius of convergence + return wser290(p) # good for x approx .32 +end + +# These may need tuning. +function wser(p::Complex{T},z) where T<:Real + x = abs(z) + x < 4e-11 && return wser3(p) + x < 1e-5 && return wser7(p) + x < 1e-3 && return wser12(p) + x < 1e-2 && return wser19(p) + x < 3e-2 && return wser26(p) + x < 5e-2 && return wser32(p) + x < 1e-1 && return wser50(p) + x < 1.9e-1 && return wser100(p) + x > 1/MathConstants.e && throw(DomainError(x)) # radius of convergence + return wser290(p) +end + +@inline function _lambertw0(x) # 1 + W(-1/e + x) , k = 0 + ps = 2*MathConstants.e*x; + p = sqrt(ps) + return wser(p,x) +end + +@inline function _lambertwm1(x) # 1 + W(-1/e + x) , k = -1 + ps = 2*MathConstants.e*x; + p = -sqrt(ps) + return wser(p,x) +end + +""" + lambertwbp(z,k=0) + +Compute an accurate value of `1 + W(-1/e + z)`, for `abs(z)` in `[0,1/e]` for `k` either `0` or `-1`. +The result is accurate to Float64 precision for abs(z) < 0.32. +If `k=-1` and `imag(z) < 0`, the value on the branch `k=1` is returned. + +```jldoctest +julia> lambertw(-1/e + 1e-18, -1) +-1.0 + +julia> lambertwbp(1e-18, -1) +-2.331643983409312e-9 + +# Same result, but 1000 times slower +julia> convert(Float64,(lambertw(-BigFloat(1)/e + BigFloat(10)^(-18),-1) + 1)) +-2.331643983409312e-9 +``` + +!!! note + `lambertwbp` uses a series expansion about the branch point `z=-1/e` to avoid loss of precision. + The loss of precision in `lambertw` is analogous to the loss of precision + in computing `sqrt(1-x)` for `x` close to `1`. 
+""" +function lambertwbp(x::Number,k::Integer) + k == 0 && return _lambertw0(x) + k == -1 && return _lambertwm1(x) + throw(ArgumentError("expansion about branch point only implemented for k = 0 and -1.")) +end + +lambertwbp(x::Number) = _lambertw0(x) diff --git a/test/lambertw_test.jl b/test/lambertw_test.jl new file mode 100644 index 00000000..46d5dd57 --- /dev/null +++ b/test/lambertw_test.jl @@ -0,0 +1,156 @@ +using Compat + +import Compat.MathConstants + +### domain errors + +@test_throws DomainError lambertw(-2.0,0) +@test_throws DomainError lambertw(-2.0,-1) +@test_throws DomainError lambertw(-2.0,1) +@test isnan(lambertw(NaN)) + +## math constant e +@test_throws DomainError lambertw(MathConstants.e,1) +@test_throws DomainError lambertw(MathConstants.e,-1) + +## integer arguments return floating point types +@test typeof(lambertw(0)) <: AbstractFloat +@test lambertw(0) == 0 + +### math constant MathConstants.e + +# could return math const e, but this would break type stability +@test typeof(lambertw(1)) <: AbstractFloat +@test lambertw(MathConstants.e,0) == 1 + +## value at branch point where real branches meet +@test lambertw(-1/MathConstants.e,0) == lambertw(-1/MathConstants.e,-1) == -1 +@test typeof(lambertw(-1/MathConstants.e,0)) == typeof(lambertw(-1/MathConstants.e,-1)) <: AbstractFloat + +## convert irrationals to float + +@test isapprox(lambertw(pi), 1.0736581947961492) +@test isapprox(lambertw(pi,0), 1.0736581947961492) + +### infinite args or return values + +@test lambertw(0,-1) == lambertw(0.0,-1) == -Inf +@test lambertw(Inf,0) == Inf +@test lambertw(complex(Inf,1),0) == complex(Inf,1) +@test lambertw(complex(Inf,0),1) == complex(Inf,2pi) +@test lambertw(complex(-Inf,0),1) == complex(Inf,3pi) +@test lambertw(complex(0.0,0.0),-1) == complex(-Inf,0.0) + +## default branch is k = 0 +@test lambertw(1.0) == lambertw(1.0,0) + +## BigInt args return BigFloats +@test typeof(lambertw(BigInt(0))) == BigFloat +@test typeof(lambertw(BigInt(3))) == 
BigFloat + +## Any Integer type allowed for second argument +@test lambertw(-0.2,-1) == lambertw(-0.2,BigInt(-1)) + +## BigInt for second arg does not promote the type +@test typeof(lambertw(-0.2,-1)) == typeof(lambertw(-0.2,BigInt(-1))) + +for (z,k,res) in [ (0,0 ,0), (complex(0,0),0 ,0), + (complex(0.0,0),0 ,0), (complex(1.0,0),0, 0.567143290409783873) ] + if Int != Int32 + @test isapprox(lambertw(z,k), res) + @test isapprox(lambertw(z), res) + else + @test isapprox(lambertw(z,k), res; rtol = 1e-14) + @test isapprox(lambertw(z), res; rtol = 1e-14) + end +end + +for (z,k) in ((complex(1,1),2), (complex(1,1),0),(complex(.6,.6),0), + (complex(.6,-.6),0)) + let w + @test (w = lambertw(z,k) ; true) + @test abs(w*exp(w) - z) < 1e-15 + end +end + +@test abs(lambertw(complex(-3.0,-4.0),0) - Complex(1.075073066569255, -1.3251023817343588)) < 1e-14 +@test abs(lambertw(complex(-3.0,-4.0),1) - Complex(0.5887666813694675, 2.7118802109452247)) < 1e-14 +@test (lambertw(complex(.3,.3),0); true) + +# bug fix +# The routine will start at -1/e + eps * im, rather than -1/e + 0im, +# otherwise root finding will fail +if Int != Int32 + @test abs(lambertw(-1.0/MathConstants.e + 0im,-1)) == 1 +else + @test abs(lambertw(-1.0/MathConstants.e + 0im,-1) + 1) < 1e-7 +end +# lambertw for BigFloat is more precise than Float64. 
Note +# that 70 digits in test is about 35 digits in W +let W + for z in [ BigFloat(1), BigFloat(2), complex(BigFloat(1), BigFloat(1))] + @test (W = lambertw(z); true) + @test abs(z - W * exp(W)) < BigFloat(10)^(-70) + end +end + +### ω constant + +## get ω from recursion and compare to value from lambertw +let sp = precision(BigFloat) + setprecision(512) + @test lambertw(big(1)) == big(SpecialFunctions.omega) + setprecision(sp) +end + +@test lambertw(1) == float(SpecialFunctions.omega) +@test convert(Float16,SpecialFunctions.omega) == convert(Float16,0.5674) +@test convert(Float32,SpecialFunctions.omega) == 0.56714326f0 +@test lambertw(BigInt(1)) == big(SpecialFunctions.omega) + +### expansion about branch point + +# not a domain error, but not implemented +@test_throws ArgumentError lambertwbp(1,1) + +@test_throws DomainError lambertw(.3,2) + +# Expansions about branch point converges almost to machine precision +# except near the radius of convergence. +# Complex args are not tested here. + +if Int != Int32 + +# Test double-precision expansion near branch point using BigFloats +let sp = precision(BigFloat), z = BigFloat(1)/10^12, wo, xdiff + setprecision(2048) + for i in 1:300 + innerarg = z-1/big(MathConstants.e) + + # branch k = 0 + @test (wo = lambertwbp(Float64(z)); xdiff = abs(-1 + wo - lambertw(innerarg)); true) + if xdiff > 5e-16 + println(Float64(z), " ", Float64(xdiff)) + end + @test xdiff < 5e-16 + + # branch k = -1 + @test (wo = lambertwbp(Float64(z),-1); xdiff = abs(-1 + wo - lambertw(innerarg,-1)); true) + if xdiff > 5e-16 + println(Float64(z), " ", Float64(xdiff)) + end + @test xdiff < 5e-16 + z *= 1.1 + if z > 0.23 break end + + end + setprecision(sp) +end + +# test the expansion about branch point for k=-1, +# by comparing to exact BigFloat calculation. 
+@test abs(lambertwbp(1e-20,-1) - (1 + lambertw(-BigFloat(1)/big(MathConstants.e)+ BigFloat(1)/BigFloat(10)^BigFloat(20),-1))) < 1e-16 + +@test abs(lambertwbp(Complex(.01,.01),-1) - Complex(-0.2755038208041206, -0.1277888928494641)) < 1e-14 + +end # if Int != Int32 diff --git a/test/runtests.jl b/test/runtests.jl index c82a4cfa..4501b256 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,6 +20,10 @@ relerr(z, x) = z == x ? 0.0 : abs(z - x) / abs(x) relerrc(z, x) = max(relerr(real(z),real(x)), relerr(imag(z),imag(x))) ≅(a,b) = relerrc(a,b) ≤ 1e-13 +@testset "Lambert W" begin + include("lambertw_test.jl") +end + @testset "error functions" begin @test SF.erf(Float16(1)) ≈ 0.84270079294971486934 @test SF.erf(1) ≈ 0.84270079294971486934