Skip to content

Commit daefe40

Browse files
authored
Specialize _matmul! on ForwardDiff.Duals (#117)
* _matmul! with matrix multiplying array of ForwardDiff.Dual numbers * _matmul! with matrix of ForwardDiff.Dual numbers multiplying array * fall back to standard method if possible * add compat bound for Requires * increase test coverage * set version to v0.3.5
1 parent 9f56c5c commit daefe40

File tree

7 files changed

+122
-5
lines changed

7 files changed

+122
-5
lines changed

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
name = "Octavian"
22
uuid = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
33
authors = ["Mason Protter", "Chris Elrod", "Dilum Aluthge", "contributors"]
4-
version = "0.3.4"
4+
version = "0.3.5"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
99
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
1010
ManualMemory = "d125e4d3-2237-4719-b19c-fa641b8a4667"
1111
PolyesterWeave = "1d0040c9-8b98-4ee7-8388-3f51789ca0ad"
12+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1213
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
1314
ThreadingUtilities = "8290d209-cae3-49c0-8002-c8c24d57dab5"
1415
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
@@ -19,6 +20,7 @@ IfElse = "0.1"
1920
LoopVectorization = "0.12.86"
2021
ManualMemory = "0.1.1"
2122
PolyesterWeave = "0.1.1"
23+
Requires = "1"
2224
Static = "0.2, 0.3"
2325
ThreadingUtilities = "0.4.6"
2426
VectorizationBase = "0.21.15"
@@ -27,6 +29,7 @@ julia = "1.6"
2729
[extras]
2830
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
2931
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
32+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3033
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
3134
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
3235
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
@@ -35,4 +38,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3538
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
3639

3740
[targets]
38-
test = ["Aqua", "BenchmarkTools", "InteractiveUtils", "LinearAlgebra", "LoopVectorization", "Random", "VectorizationBase", "Test"]
41+
test = ["Aqua", "BenchmarkTools", "ForwardDiff", "InteractiveUtils", "LinearAlgebra", "LoopVectorization", "Random", "VectorizationBase", "Test"]

src/Octavian.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
module Octavian
22

3+
using Requires: @require
4+
35
using VectorizationBase, ArrayInterface, LoopVectorization
46

57
using VectorizationBase: align, AbstractStridedPointer, zstridedpointer, vsub_nsw, assume,
68
static_sizeof, StridedPointer, gesp, pause, pick_vector_width, has_feature,
7-
cache_size, num_cores, num_cores, cache_inclusive, cache_linesize
9+
cache_size, num_cores, cache_inclusive, cache_linesize
810
using LoopVectorization: preserve_buffer, CloseOpen, UpperBoundedInteger
911
using ArrayInterface: size, strides, offsets, indices, axes, StrideIndex
1012
using IfElse: ifelse

src/forward_diff.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
2+
"""
    real_rep(a::AbstractArray{<:ForwardDiff.Dual})

Return a real-valued representation of the dual-number array `a` by
reinterpreting each `ForwardDiff.Dual` element as an extra leading dimension
of its underlying value type `T` (value plus partials). No data is copied.
"""
function real_rep(a::AbstractArray{DualT}) where {TAG, T, DualT<:ForwardDiff.Dual{TAG, T}}
    return reinterpret(reshape, T, a)
end
3+
4+
# Multiplication of a dual vector/matrix by a standard (real) matrix from the left:
#   _C = α * A * _B + β * _C
# where `_C` and `_B` have `ForwardDiff.Dual` elements and `A` is real.
# The dual arrays are reinterpreted (via `real_rep`) as real arrays with an
# extra leading dimension `l` running over value + partials, so the kernel
# works purely on real numbers.
@inline function _matmul!(_C::AbstractVecOrMat{DualT}, A::AbstractMatrix, _B::AbstractVecOrMat{DualT},
    α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {DualT <: ForwardDiff.Dual}
  B = real_rep(_B)
  C = real_rep(_C)

  # C[l, m, n] = α * Σₖ A[m, k] * B[l, k, n] + β * C[l, m, n]
  @tturbo for n in indices((C, B), 3), m in indices((C, A), (2, 1)), l in indices((C, B), 1)
    Cₗₘₙ = zero(eltype(C))
    for k in indices((A, B), 2)
      Cₗₘₙ += A[m, k] * B[l, k, n]
    end
    C[l, m, n] = α * Cₗₘₙ + β * C[l, m, n]
  end

  _C
end
20+
21+
# Multiplication of a dual matrix by a standard (real) vector/matrix from the right:
#   _C = α * _A * B + β * _C
# where `_C` and `_A` have `ForwardDiff.Dual` elements and `B` is real.
# Fast path: when `_C` and `_A` are dense and column-major, a flat `reinterpret`
# to the underlying real type `T` preserves the column-major layout, so the
# standard real `_matmul!` method can be called directly.
# Slow path: otherwise, reinterpret with a leading partials dimension and run a
# dedicated kernel.
@inline function _matmul!(_C::AbstractVecOrMat{DualT}, _A::AbstractMatrix{DualT}, B::AbstractVecOrMat,
    α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing) where {TAG, T, DualT <: ForwardDiff.Dual{TAG, T}}
  if all((ArrayInterface.is_dense(_C), ArrayInterface.is_column_major(_C),
          ArrayInterface.is_dense(_A), ArrayInterface.is_column_major(_A)))
    # we can avoid the reshape and call the standard method
    A = reinterpret(T, _A)
    C = reinterpret(T, _C)
    _matmul!(C, A, B, α, β, nthread, MKN)
  else
    # we cannot use the standard method directly
    A = real_rep(_A)
    C = real_rep(_C)

    # C[l, m, n] = α * Σₖ A[l, m, k] * B[k, n] + β * C[l, m, n]
    @tturbo for n in indices((C, B), (3, 2)), m in indices((C, A), 2), l in indices((C, A), 1)
      Cₗₘₙ = zero(eltype(C))
      for k in indices((A, B), (3, 1))
        Cₗₘₙ += A[l, m, k] * B[k, n]
      end
      C[l, m, n] = α * Cₗₘₙ + β * C[l, m, n]
    end
  end

  _C
end

src/init.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
function __init__()
2+
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("forward_diff.jl")
3+
24
init_acache()
35
init_bcache()
46
nt = init_num_tasks()

src/macrokernels.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@ end
109109
for k CloseOpen(K)
110110
Aₘₖ = A[m,k]
111111
Cₘₙ += Aₘₖ * B[k,n]
112-
Ãₚ[m,k] = Aₘₖ
112+
Ãₚ[m,k] = Aₘₖ
113113
end
114-
C[m,n] = α * Cₘₙ + β * C[m,n]
114+
C[m,n] = α * Cₘₙ + β * C[m,n]
115115
end
116116
end
117117
@inline function alloc_a_pack(K, ::Val{T}) where {T}

test/forward_diff.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
2+
@time @testset "ForwardDiff.jl" begin
    m = 5
    n = 6
    k = 7

    A1 = rand(Float64, m, k)
    B1 = rand(Float64, k, n)
    C1 = rand(Float64, m, n)

    A2 = deepcopy(A1)
    B2 = deepcopy(B1)
    C2 = deepcopy(C1)

    α = Float64(2.0)
    β = Float64(2.0)

    # Baseline: Octavian agrees with LinearAlgebra on plain Float64 matrices.
    Octavian.matmul!(C1, A1, B1, α, β)
    LinearAlgebra.mul!(C2, A2, B2, α, β)
    @test C1 ≈ C2

    @testset "real array from the left" begin
        config = ForwardDiff.JacobianConfig(nothing, C1, B1)
        I = LinearAlgebra.I(size(B1, 2))

        # d(A*B)/dB = I ⊗ A (Kronecker product), so the Jacobian is known exactly.
        J1 = ForwardDiff.jacobian((C, B) -> Octavian.matmul!(C, A1, B), C1, B1, config)
        @test J1 ≈ kron(I, A1)

        J2 = ForwardDiff.jacobian((C, B) -> LinearAlgebra.mul!(C, A2, B), C2, B2, config)
        @test J1 ≈ kron(I, A2)
        @test J1 ≈ J2
    end

    @testset "real array from the right" begin
        # dense and column-major arrays
        config = ForwardDiff.JacobianConfig(nothing, C1, A1)

        J1 = ForwardDiff.jacobian((C, A) -> Octavian.matmul!(C, A, B1), C1, A1, config)
        J2 = ForwardDiff.jacobian((C, A) -> LinearAlgebra.mul!(C, A, B2), C2, A2, config)
        @test J1 ≈ J2

        # transposed arrays (exercise the non-column-major fallback path)
        A1new = Matrix(A1')'
        A2new = Matrix(A2')'
        config = ForwardDiff.JacobianConfig(nothing, C1, A1new)

        J1 = ForwardDiff.jacobian((C, A) -> Octavian.matmul!(C, A, B1), C1, A1new, config)
        J2 = ForwardDiff.jacobian((C, A) -> LinearAlgebra.mul!(C, A, B2), C2, A2new, config)
        @test J1 ≈ J2

        # direct version using dual numbers
        A1dual = zeros(eltype(config), reverse(size(A1))...)
        A1dual .= A1'
        C1dual = zeros(eltype(config), size(C1)...)

        A2dual = deepcopy(A1dual)
        C2dual = deepcopy(C1dual)

        Octavian.matmul!(C1dual, A1dual', B1)
        Octavian.matmul!(C2dual, A2dual', B2)
        @test C1dual ≈ C2dual
    end
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import Octavian
22

33
import Aqua
44
import BenchmarkTools
5+
import ForwardDiff
56
import InteractiveUtils
67
import LinearAlgebra
78
import LoopVectorization
@@ -24,6 +25,7 @@ include("integer_division.jl")
2425
include("macrokernels.jl")
2526
include("matmul_coverage.jl")
2627
include("utils.jl")
28+
include("forward_diff.jl")
2729

2830
if !coverage
2931
include("matmul_main.jl")

0 commit comments

Comments
 (0)