Skip to content

Commit 0cfe997

Browse files
authored
Add support for both arrays containing dual numbers. (#118)
1 parent daefe40 commit 0cfe997

File tree

3 files changed

+43
-5
lines changed

3 files changed

+43
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Octavian"
22
uuid = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
33
authors = ["Mason Protter", "Chris Elrod", "Dilum Aluthge", "contributors"]
4-
version = "0.3.5"
4+
version = "0.3.6"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/forward_diff.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ end
2626
# we can avoid the reshape and call the standard method
2727
A = reinterpret(T, _A)
2828
C = reinterpret(T, _C)
29-
_matmul!(C, A, B, α, β, nthread, MKN)
29+
_matmul!(C, A, B, α, β, nthread, nothing)
3030
else
3131
# we cannot use the standard method directly
3232
A = real_rep(_A)
@@ -43,3 +43,33 @@ end
4343

4444
_C
4545
end
46+
47+
_view1(B::AbstractMatrix) = @view(B[1,:])
48+
_view1(B::AbstractArray{<:Any,3}) = @view(B[1,:,:])
49+
@inline function _matmul!(_C::AbstractVecOrMat{DualT}, _A::AbstractMatrix{DualT}, _B::AbstractVecOrMat{DualT},
50+
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing) where {TAG, T, P, DualT <: ForwardDiff.Dual{TAG, T, P}}
51+
A = real_rep(_A)
52+
C = real_rep(_C)
53+
B = real_rep(_B)
54+
if all((ArrayInterface.is_dense(_C), ArrayInterface.is_column_major(_C),
55+
ArrayInterface.is_dense(_A), ArrayInterface.is_column_major(_A)))
56+
# we can avoid the reshape and call the standard method
57+
Ar = reinterpret(T, _A)
58+
Cr = reinterpret(T, _C)
59+
_matmul!(Cr, Ar, _view1(B), α, β, nthread, nothing)
60+
else
61+
# we cannot use the standard method directly
62+
@tturbo for n indices((C, B), 3), m indices((C, A), 2), l in indices((C, A), 1)
63+
Cₗₘₙ = zero(eltype(C))
64+
for k indices((A, B), (3, 2))
65+
Cₗₘₙ += A[l, m, k] * B[1, k, n]
66+
end
67+
C[l, m, n] = α * Cₗₘₙ + β * C[l, m, n]
68+
end
69+
end
70+
Pstatic = static(P)
71+
@tturbo for n indices((B,C),3), m indices((A,C),2), p 1:Pstatic, k indices((A,B),(3,2))
72+
C[p+1,m,n] += A[1,m,k] * B[p+1,k,n]
73+
end
74+
_C
75+
end

test/forward_diff.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11

2+
randdual(x, ::Val{N}=Val(3)) where {N} = ForwardDiff.Dual(x, ntuple(_ -> randn(), Val(N))...)
23
@time @testset "ForwardDiff.jl" begin
3-
m = 5
4-
n = 6
5-
k = 7
4+
m = 53
5+
n = 63
6+
k = 73
67

78
A1 = rand(Float64, m, k)
89
B1 = rand(Float64, k, n)
@@ -60,4 +61,11 @@
6061
Octavian.matmul!(C2dual, A2dual', B2)
6162
@test C1dual C2dual
6263
end
64+
65+
@testset "two dual arrays" begin
66+
A1d = randdual.(A1)
67+
B1d = randdual.(B1)
68+
@test reinterpret(Float64, Octavian.matmul(A1d, B1d)) reinterpret(Float64, A1d * B1d)
69+
@test reinterpret(Float64, Octavian.matmul(@view(A1d[begin:end-1,:]), B1d)) reinterpret(Float64, @view(A1d[begin:end-1,:]) * B1d)
70+
end
6371
end

0 commit comments

Comments
 (0)