We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 415ec0a commit cf7f7d0Copy full SHA for cf7f7d0
Project.toml
@@ -1,6 +1,6 @@
1
name = "Zygote"
2
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
3
-version = "0.6.64"
+version = "0.6.65"
4
5
[deps]
6
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
src/lib/array.jl
@@ -368,12 +368,10 @@ function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
368
return reshape(mat1_rsh.*mat2_rsh, (m1*m2,n1*n2))
369
end
370
_kron(a::AbstractVector, b::AbstractVector) = vec(_kron(reshape(a, :, 1), reshape(b, :, 1)))
371
+_kron(a::AbstractVector, b::AbstractMatrix) = _kron(reshape(a, :, 1), b)
372
+_kron(a::AbstractMatrix, b::AbstractVector) = _kron(a, reshape(b, :, 1))
373
-function _pullback(cx::AContext, ::typeof(kron), a::AbstractVector, b::AbstractVector)
- res, back = _pullback(cx, _kron, a, b)
374
- return res, back ∘ unthunk_tangent
375
-end
376
-function _pullback(cx::AContext, ::typeof(kron), a::AbstractMatrix, b::AbstractMatrix)
+function _pullback(cx::AContext, ::typeof(kron), a::AbstractVecOrMat, b::AbstractVecOrMat)
377
res, back = _pullback(cx, _kron, a, b)
378
return res, back ∘ unthunk_tangent
379
test/gradcheck.jl
@@ -275,6 +275,8 @@ end
275
@test gradtest(kron, rand(5,1), rand(3,1))
276
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
277
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
278
+@test gradtest(kron, rand(5), rand(3, 2))
279
+@test gradtest(kron, rand(3, 2), rand(5))
280
281
for mapfunc in [map,pmap]
282
@testset "$mapfunc" begin
0 commit comments