Skip to content

Commit 781c83a

Browse files
committed
force specialization of dims argument type where it may be Colon
As discussed in the Performance Tips in the Manual, Julia avoids specializing calls on the type of `Function` arguments by default. The function `:`, `(:) isa Function`, is often used not as a callable, but as the special value for specifying the "full range" or "all dimensions". However in such cases we often forget to force specialization on the type of `:`. This change fixes that. See PR JuliaLang/julia#59474, which applies the same kind of change in Julia itself. NB: I don't have an example where this change helps for StaticArrays, however the eliminated invalidation in the linked JuliaLang/julia PR is proof that it does help in some cases. Finding such examples is difficult because the compiler is often able to achieve good results because of constant propagation. However constprop is often fragile, so it is better to avoid relying on it. For example, constprop through recursion is not even attempted by the Julia compiler. I believe this change should not cause any real-world compile time regression, as `:` is the only function that is valid as a dims argument.
1 parent bb20cf4 commit 781c83a

File tree

2 files changed

+28
-28
lines changed

2 files changed

+28
-28
lines changed

ext/StaticArraysStatisticsExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ _mean_denom(a, ::Colon) = length(a)
1212
_mean_denom(a, dims::Int) = size(a, dims)
1313
_mean_denom(a, ::Val{D}) where {D} = size(a, D)
1414

15-
@inline mean(a::StaticArray; dims=:) = _reduce(+, a, dims) / _mean_denom(a, dims)
16-
@inline mean(f::Function, a::StaticArray; dims=:) = _mapreduce(f, +, dims, _InitialValue(), Size(a), a) / _mean_denom(a, dims)
15+
@inline mean(a::StaticArray; dims::D=:) where {D} = _reduce(+, a, dims) / _mean_denom(a, dims)
16+
@inline mean(f::Function, a::StaticArray; dims::D=:) where {D} = _mapreduce(f, +, dims, _InitialValue(), Size(a), a) / _mean_denom(a, dims)
1717

18-
@inline function median(a::StaticArray; dims = :)
18+
@inline function median(a::StaticArray; dims::D = :) where {D}
1919
if dims == Colon()
2020
median(vec(a))
2121
else

src/mapreduce.jl

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ end
146146
## mapreduce ##
147147
###############
148148

149-
@inline function mapreduce(f, op, a::StaticArray, b::StaticArray...; dims=:, init = _InitialValue())
149+
@inline function mapreduce(f, op, a::StaticArray, b::StaticArray...; dims::D=:, init = _InitialValue()) where {D}
150150
_mapreduce(f, op, dims, init, same_size(a, b...), a, b...)
151151
end
152152

@@ -235,7 +235,7 @@ end
235235
## reduce ##
236236
############
237237

238-
@inline reduce(op::R, a::StaticArray; dims = :, init = _InitialValue()) where {R} =
238+
@inline reduce(op::R, a::StaticArray; dims::D = :, init = _InitialValue()) where {D, R} =
239239
_reduce(op, a, dims, init)
240240

241241
# disambiguation
@@ -249,7 +249,7 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:AbstractVecOrMat}) =
249249
reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) =
250250
_reduce(hcat, A, :, _InitialValue())
251251

252-
@inline _reduce(op::R, a::StaticArray, dims, init = _InitialValue()) where {R} =
252+
@inline _reduce(op::R, a::StaticArray, dims::D, init = _InitialValue()) where {D, R} =
253253
_mapreduce(identity, op, dims, init, Size(a), a)
254254

255255
################
@@ -264,7 +264,7 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) =
264264
_mapfoldl(f, op, :, init, Size(a), a)
265265
@inline foldl(op::R, a::StaticArray; init = _InitialValue()) where {R} =
266266
_foldl(op, a, :, init)
267-
@inline _foldl(op::R, a, dims, init = _InitialValue()) where {R} =
267+
@inline _foldl(op::R, a, dims::D, init = _InitialValue()) where {D, R} =
268268
_mapfoldl(identity, op, dims, init, Size(a), a)
269269

270270
#######################
@@ -290,33 +290,33 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) =
290290
# TODO: change to use Base.reduce_empty/Base.reduce_first
291291
@inline iszero(a::StaticArray{<:Tuple,T}) where {T} = reduce((x,y) -> x && iszero(y), a, init=true)
292292

293-
@inline sum(a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _reduce(+, a, dims, init)
294-
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, +, dims, init, Size(a), a)
295-
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, +, dims, init, Size(a), a) # avoid ambiguity
293+
@inline sum(a::StaticArray{<:Tuple,T}; dims::D=:, init=_InitialValue()) where {D, T} = _reduce(+, a, dims, init)
294+
@inline sum(f, a::StaticArray{<:Tuple,T}; dims::D=:, init=_InitialValue()) where {D, T} = _mapreduce(f, +, dims, init, Size(a), a)
295+
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims::D=:, init=_InitialValue()) where {D, T} = _mapreduce(f, +, dims, init, Size(a), a) # avoid ambiguity
296296

297-
@inline prod(a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _reduce(*, a, dims, init)
298-
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, *, dims, init, Size(a), a)
299-
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, *, dims, init, Size(a), a)
297+
@inline prod(a::StaticArray{<:Tuple,T}; dims::D=:, init=_InitialValue()) where {D, T} = _reduce(*, a, dims, init)
298+
@inline prod(f, a::StaticArray{<:Tuple,T}; dims::D=:, init=_InitialValue()) where {D, T} = _mapreduce(f, *, dims, init, Size(a), a)
299+
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims::D=:, init=_InitialValue()) where {D, T} = _mapreduce(f, *, dims, init, Size(a), a)
300300

301-
@inline count(a::StaticArray{<:Tuple,Bool}; dims=:, init=0) = _reduce(+, a, dims, init)
302-
@inline count(f, a::StaticArray; dims=:, init=0) = _mapreduce(x->f(x)::Bool, +, dims, init, Size(a), a)
301+
@inline count(a::StaticArray{<:Tuple,Bool}; dims::D=:, init=0) where {D} = _reduce(+, a, dims, init)
302+
@inline count(f, a::StaticArray; dims::D=:, init=0) where {D} = _mapreduce(x->f(x)::Bool, +, dims, init, Size(a), a)
303303

304-
@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(&, a, dims, true) # non-branching versions
305-
@inline all(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, &, dims, true, Size(a), a)
304+
@inline all(a::StaticArray{<:Tuple,Bool}; dims::D=:) where {D} = _reduce(&, a, dims, true) # non-branching versions
305+
@inline all(f::Function, a::StaticArray; dims::D=:) where {D} = _mapreduce(x->f(x)::Bool, &, dims, true, Size(a), a)
306306

307-
@inline any(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(|, a, dims, false) # (benchmarking needed)
308-
@inline any(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, |, dims, false, Size(a), a) # (benchmarking needed)
307+
@inline any(a::StaticArray{<:Tuple,Bool}; dims::D=:) where {D} = _reduce(|, a, dims, false) # (benchmarking needed)
308+
@inline any(f::Function, a::StaticArray; dims::D=:) where {D} = _mapreduce(x->f(x)::Bool, |, dims, false, Size(a), a) # (benchmarking needed)
309309

310310
@inline Base.in(x, a::StaticArray) = _mapreduce(==(x), |, :, false, Size(a), a)
311311

312-
@inline minimum(a::StaticArray; dims=:) = _reduce(min, a, dims) # base has mapreduce(identity, scalarmin, a)
313-
@inline minimum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, min, dims, _InitialValue(), Size(a), a)
312+
@inline minimum(a::StaticArray; dims::D=:) where {D} = _reduce(min, a, dims) # base has mapreduce(identity, scalarmin, a)
313+
@inline minimum(f::Function, a::StaticArray; dims::D=:) where {D} = _mapreduce(f, min, dims, _InitialValue(), Size(a), a)
314314

315-
@inline maximum(a::StaticArray; dims=:) = _reduce(max, a, dims) # base has mapreduce(identity, scalarmax, a)
316-
@inline maximum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, max, dims, _InitialValue(), Size(a), a)
315+
@inline maximum(a::StaticArray; dims::D=:) where {D} = _reduce(max, a, dims) # base has mapreduce(identity, scalarmax, a)
316+
@inline maximum(f::Function, a::StaticArray; dims::D=:) where {D} = _mapreduce(f, max, dims, _InitialValue(), Size(a), a)
317317

318318
# Diff is slightly different
319-
@inline diff(a::StaticArray; dims) = _diff(Size(a), a, dims)
319+
@inline diff(a::StaticArray; dims::D) where {D} = _diff(Size(a), a, dims)
320320
@inline diff(a::StaticVector) = diff(a;dims=Val(1))
321321

322322
@inline function _diff(sz::Size{S}, a::StaticArray, D::Int) where {S}
@@ -343,16 +343,16 @@ end
343343
end
344344

345345
_maybe_val(dims::Integer) = Val(Int(dims))
346-
_maybe_val(dims) = dims
346+
_maybe_val(dims::D) where {D} = dims
347347
_valof(::Val{D}) where D = D
348348

349-
@inline Base.accumulate(op::F, a::StaticVector; dims = :, init = _InitialValue()) where {F} =
349+
@inline Base.accumulate(op::F, a::StaticVector; dims::D = :, init = _InitialValue()) where {D, F} =
350350
_accumulate(op, a, _maybe_val(dims), init)
351351

352-
@inline Base.accumulate(op::F, a::StaticArray; dims, init = _InitialValue()) where {F} =
352+
@inline Base.accumulate(op::F, a::StaticArray; dims::D, init = _InitialValue()) where {D, F} =
353353
_accumulate(op, a, _maybe_val(dims), init)
354354

355-
@inline function _accumulate(op::F, a::StaticArray, dims::Union{Val,Colon}, init) where {F}
355+
@inline function _accumulate(op::F, a::StaticArray, dims::Dimensions, init) where {Dimensions <: Union{Val,Colon}, F}
356356
# Adjoin the initial value to `op` (one-line version of `Base.BottomRF`):
357357
rf(x, y) = x isa _InitialValue ? Base.reduce_first(op, y) : op(x, y)
358358

0 commit comments

Comments
 (0)