Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.44.7"
version = "1.45.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
6 changes: 6 additions & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ import ChainRulesCore: rrule, frule
# Experimental:
using ChainRulesCore: derivatives_given_output

if isdefined(Base, :stack)
using Base: stack
else
using Compat: stack
end

# numbers that we know commute under multiplication
const CommutativeMulNumber = Union{Real,Complex}

Expand Down
13 changes: 13 additions & 0 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -610,3 +610,16 @@ function _extrema_dims(x, dims)
end
return y, extrema_pullback_dims
end

#####
##### `stack`
#####

function rrule(::typeof(stack), xs; dims::Union{Integer, Colon} = :)
dims = dims === Colon() ? ndims(first(xs)) + 1 : dims
function stack_pullback(Δ)
dy = unthunk(Δ)
return (NoTangent(), [copy(selectdim(dy, dims, i)) for i in 1:size(dy, dims)])
end
return stack(xs; dims), stack_pullback
end
9 changes: 9 additions & 0 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,12 @@ end
B = hcat(A[:,:,1], A[:,:,1])
@test extrema(B, dims=2) == rrule(extrema, B, dims=2)[1]
end

@testset "stack" begin
xs = [rand(3, 4), rand(3, 4)]

test_rrule(stack, xs, check_inferred=false)
test_rrule(stack, xs, fkwargs=(dims=1,), check_inferred=false)
test_rrule(stack, xs, fkwargs=(dims=2,), check_inferred=false)
test_rrule(stack, xs, fkwargs=(dims=3,), check_inferred=false)
end