Skip to content

_map -> Base.map #453

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.37"
version = "0.10.38"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -19,7 +19,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
ChainRulesCore = "1"
Expand All @@ -34,5 +33,4 @@ Requires = "1.0.1"
SpecialFunctions = "0.8, 0.9, 0.10, 1, 2"
StatsBase = "0.32, 0.33"
TensorCore = "0.1"
ZygoteRules = "0.2"
julia = "1.3"
2 changes: 0 additions & 2 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ using IrrationalConstants: logtwo, twoπ, invsqrt2
using LogExpFunctions: softplus
using StatsBase
using TensorCore
using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield

# Hack to work around Zygote type inference problems.
const Distances_pairwise = Distances.pairwise
Expand Down Expand Up @@ -122,7 +121,6 @@ include("mokernels/intrinsiccoregion.jl")
include("mokernels/lmm.jl")

include("chainrules.jl")
include("zygoterules.jl")

include("test_utils.jl")

Expand Down
16 changes: 8 additions & 8 deletions src/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,37 +82,37 @@
# Kernel matrix operations

function kernelmatrix_diag!(K::AbstractVector, κ::TransformedKernel, x::AbstractVector)
return kernelmatrix_diag!(K, κ.kernel, _map(κ.transform, x))
return kernelmatrix_diag!(K, κ.kernel, map(κ.transform, x))

Check warning on line 85 in src/kernels/transformedkernel.jl

View check run for this annotation

Codecov / codecov/patch

src/kernels/transformedkernel.jl#L85

Added line #L85 was not covered by tests
end

function kernelmatrix_diag!(
K::AbstractVector, κ::TransformedKernel, x::AbstractVector, y::AbstractVector
)
return kernelmatrix_diag!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
return kernelmatrix_diag!(K, κ.kernel, map(κ.transform, x), map(κ.transform, y))

Check warning on line 91 in src/kernels/transformedkernel.jl

View check run for this annotation

Codecov / codecov/patch

src/kernels/transformedkernel.jl#L91

Added line #L91 was not covered by tests
end

function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector)
return kernelmatrix!(K, κ.kernel, _map(κ.transform, x))
return kernelmatrix!(K, κ.kernel, map(κ.transform, x))

Check warning on line 95 in src/kernels/transformedkernel.jl

View check run for this annotation

Codecov / codecov/patch

src/kernels/transformedkernel.jl#L95

Added line #L95 was not covered by tests
end

function kernelmatrix!(
K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector, y::AbstractVector
)
return kernelmatrix!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
return kernelmatrix!(K, κ.kernel, map(κ.transform, x), map(κ.transform, y))

Check warning on line 101 in src/kernels/transformedkernel.jl

View check run for this annotation

Codecov / codecov/patch

src/kernels/transformedkernel.jl#L101

Added line #L101 was not covered by tests
end

function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector)
return kernelmatrix_diag(κ.kernel, _map(κ.transform, x))
return kernelmatrix_diag(κ.kernel, map(κ.transform, x))

Check warning on line 105 in src/kernels/transformedkernel.jl

View check run for this annotation

Codecov / codecov/patch

src/kernels/transformedkernel.jl#L105

Added line #L105 was not covered by tests
end

function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector, y::AbstractVector)
return kernelmatrix_diag(κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
return kernelmatrix_diag(κ.kernel, map(κ.transform, x), map(κ.transform, y))

Check warning on line 109 in src/kernels/transformedkernel.jl

View check run for this annotation

Codecov / codecov/patch

src/kernels/transformedkernel.jl#L109

Added line #L109 was not covered by tests
end

function kernelmatrix(κ::TransformedKernel, x::AbstractVector)
return kernelmatrix(κ.kernel, _map(κ.transform, x))
return kernelmatrix(κ.kernel, map(κ.transform, x))

Check warning on line 113 in src/kernels/transformedkernel.jl

View check run for this annotation

Codecov / codecov/patch

src/kernels/transformedkernel.jl#L113

Added line #L113 was not covered by tests
end

function kernelmatrix(κ::TransformedKernel, x::AbstractVector, y::AbstractVector)
return kernelmatrix(κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
return kernelmatrix(κ.kernel, map(κ.transform, x), map(κ.transform, y))

Check warning on line 117 in src/kernels/transformedkernel.jl

View check run for this annotation

Codecov / codecov/patch

src/kernels/transformedkernel.jl#L117

Added line #L117 was not covered by tests
end
6 changes: 3 additions & 3 deletions src/transform/ardtransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
(t::ARDTransform)(x::Real) = only(t.v) * x
(t::ARDTransform)(x) = t.v .* x

_map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
_map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
_map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)
Base.map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x

Check warning on line 38 in src/transform/ardtransform.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/ardtransform.jl#L38

Added line #L38 was not covered by tests
Base.map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
Base.map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)

Check warning on line 40 in src/transform/ardtransform.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/ardtransform.jl#L40

Added line #L40 was not covered by tests

Base.isequal(t::ARDTransform, t2::ARDTransform) = isequal(t.v, t2.v)

Expand Down
2 changes: 1 addition & 1 deletion src/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform(vcat(t, tc.transfor

(t::ChainTransform)(x) = foldl((x, t) -> t(x), t.transforms; init=x)

function _map(t::ChainTransform, x::AbstractVector)
function Base.map(t::ChainTransform, x::AbstractVector)
return foldl((x, t) -> map(t, x), t.transforms; init=x)
end

Expand Down
6 changes: 3 additions & 3 deletions src/transform/functiontransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@

(t::FunctionTransform)(x) = t.f(x)

_map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)
Base.map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)

Check warning on line 26 in src/transform/functiontransform.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/functiontransform.jl#L26

Added line #L26 was not covered by tests

function _map(t::FunctionTransform, x::ColVecs)
function Base.map(t::FunctionTransform, x::ColVecs)
vals = map(axes(x.X, 2)) do i
t.f(view(x.X, :, i))
end
return ColVecs(reduce(hcat, vals))
end

function _map(t::FunctionTransform, x::RowVecs)
function Base.map(t::FunctionTransform, x::RowVecs)

Check warning on line 35 in src/transform/functiontransform.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/functiontransform.jl#L35

Added line #L35 was not covered by tests
vals = map(axes(x.X, 1)) do i
t.f(view(x.X, i, :))
end
Expand Down
6 changes: 3 additions & 3 deletions src/transform/lineartransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
(t::LinearTransform)(x::Real) = vec(t.A * x)
(t::LinearTransform)(x::AbstractVector{<:Real}) = t.A * x

_map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * collect(x'))
_map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
_map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')
Base.map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * collect(x'))

Check warning on line 37 in src/transform/lineartransform.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/lineartransform.jl#L37

Added line #L37 was not covered by tests
Base.map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
Base.map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')

Check warning on line 39 in src/transform/lineartransform.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/lineartransform.jl#L39

Added line #L39 was not covered by tests

function Base.show(io::IO, t::LinearTransform)
return print(io::IO, "Linear transform (size(A) = ", size(t.A), ")")
Expand Down
2 changes: 1 addition & 1 deletion src/transform/periodic_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

(t::PeriodicTransform)(x::Real) = [sinpi(2 * only(t.f) * x), cospi(2 * only(t.f) * x)]

function _map(t::PeriodicTransform, x::AbstractVector{<:Real})
function Base.map(t::PeriodicTransform, x::AbstractVector{<:Real})

Check warning on line 30 in src/transform/periodic_transform.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/periodic_transform.jl#L30

Added line #L30 was not covered by tests
return RowVecs(hcat(sinpi.((2 * only(t.f)) .* x), cospi.((2 * only(t.f)) .* x)))
end

Expand Down
6 changes: 3 additions & 3 deletions src/transform/scaletransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@

(t::ScaleTransform)(x) = only(t.s) * x

_map(t::ScaleTransform, x::AbstractVector{<:Real}) = only(t.s) .* x
_map(t::ScaleTransform, x::ColVecs) = ColVecs(only(t.s) .* x.X)
_map(t::ScaleTransform, x::RowVecs) = RowVecs(only(t.s) .* x.X)
Base.map(t::ScaleTransform, x::AbstractVector{<:Real}) = only(t.s) .* x

Check warning on line 29 in src/transform/scaletransform.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/scaletransform.jl#L29

Added line #L29 was not covered by tests
Base.map(t::ScaleTransform, x::ColVecs) = ColVecs(only(t.s) .* x.X)
Base.map(t::ScaleTransform, x::RowVecs) = RowVecs(only(t.s) .* x.X)

Check warning on line 31 in src/transform/scaletransform.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/scaletransform.jl#L31

Added line #L31 was not covered by tests

Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(only(t.s), only(t2.s))

Expand Down
4 changes: 2 additions & 2 deletions src/transform/selecttransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
_maybe_unwrap(x) = x
_maybe_unwrap(x::AbstractArray{<:Any,0}) = x[]

_map(t::SelectTransform, x::ColVecs) = _wrap(view(x.X, t.select, :), ColVecs)
_map(t::SelectTransform, x::RowVecs) = _wrap(view(x.X, :, t.select), RowVecs)
Base.map(t::SelectTransform, x::ColVecs) = _wrap(view(x.X, t.select, :), ColVecs)
Base.map(t::SelectTransform, x::RowVecs) = _wrap(view(x.X, :, t.select), RowVecs)

Check warning on line 29 in src/transform/selecttransform.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/selecttransform.jl#L29

Added line #L29 was not covered by tests

_wrap(x::AbstractVector{<:Real}, ::Any) = x
_wrap(X::AbstractMatrix{<:Real}, ::Type{T}) where {T} = T(X)
Expand Down
5 changes: 2 additions & 3 deletions src/transform/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
"""
abstract type Transform end

Base.map(t::Transform, x::AbstractVector) = _map(t, x)
_map(t::Transform, x::AbstractVector) = t.(x)
Base.map(t::Transform, x::AbstractVector) = t.(x)

Check warning on line 8 in src/transform/transform.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/transform.jl#L8

Added line #L8 was not covered by tests

"""
IdentityTransform()
Expand All @@ -16,7 +15,7 @@
struct IdentityTransform <: Transform end

(t::IdentityTransform)(x) = x
_map(::IdentityTransform, x::AbstractVector) = x
Base.map(::IdentityTransform, x::AbstractVector) = x

Check warning on line 18 in src/transform/transform.jl

View check run for this annotation

Codecov / codecov/patch

src/transform/transform.jl#L18

Added line #L18 was not covered by tests

### TODO Maybe defining adjoints could help but so far it's not working

Expand Down
13 changes: 0 additions & 13 deletions src/zygoterules.jl

This file was deleted.

1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ include("test_utils.jl")

include("generic.jl")
include("chainrules.jl")
include("zygoterules.jl")

@testset "doctests" begin
DocMeta.setdocmeta!(
Expand Down
2 changes: 1 addition & 1 deletion test/transform/selecttransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
("ColVecs", ColVecs(randn(5, 10))),
("RowVecs", RowVecs(randn(11, 4))),
]
@test KernelFunctions._map(t, x) isa AbstractVector{Float64}
@test map(t, x) isa AbstractVector{Float64}
end
end
end
1 change: 0 additions & 1 deletion test/zygoterules.jl

This file was deleted.