Skip to content

Commit 8a19d8c

Browse files
authored
fix alpha != 1 (#119)
* fix alpha != 1 * Bump version
1 parent 0cfe997 commit 8a19d8c

File tree

4 files changed

+12
-6
lines changed

4 files changed

+12
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Octavian"
22
uuid = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
33
authors = ["Mason Protter", "Chris Elrod", "Dilum Aluthge", "contributors"]
4-
version = "0.3.6"
4+
version = "0.3.7"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/forward_diff.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,12 @@ _view1(B::AbstractArray{<:Any,3}) = @view(B[1,:,:])
6868
end
6969
end
7070
Pstatic = static(P)
71-
@tturbo for n indices((B,C),3), m indices((A,C),2), p 1:Pstatic, k indices((A,B),(3,2))
72-
C[p+1,m,n] += A[1,m,k] * B[p+1,k,n]
71+
@tturbo for n indices((B,C),3), m indices((A,C),2), p 1:Pstatic
72+
Cₚₘₙ = zero(eltype(C))
73+
for k indices((A,B),(3,2))
74+
Cₚₘₙ += A[1,m,k] * B[p+1,k,n]
75+
end
76+
C[p+1,m,n] = C[p+1,m,n] + α*Cₚₘₙ
7377
end
7478
_C
7579
end

src/matmul.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,11 +220,13 @@ end
220220
221221
Multiply matrices `A` and `B`.
222222
"""
223-
@inline function matmul(A::AbstractMatrix, B::AbstractVecOrMat)
223+
@inline function matmul(A::AbstractMatrix, B::AbstractVecOrMat, α, β)
224224
C, (M,K,N) = alloc_matmul_product(A, B)
225-
matmul!(C, A, B, One(), Zero(), nothing, (M,K,N), ArrayInterface.contiguous_axis(C))
225+
matmul!(C, A, B, α, β, nothing, (M,K,N), ArrayInterface.contiguous_axis(C))
226226
return C
227227
end
228+
@inline matmul(A::AbstractMatrix, B::AbstractVecOrMat) = matmul(A, B, One(), Zero())
229+
@inline matmul(A::AbstractMatrix, B::AbstractVecOrMat, α) = matmul(A, B, α, Zero())
228230

229231
"""
230232
matmul!(C, A, B[, α, β, max_threads])

test/forward_diff.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ randdual(x, ::Val{N}=Val(3)) where {N} = ForwardDiff.Dual(x, ntuple(_ -> randn()
6565
@testset "two dual arrays" begin
6666
A1d = randdual.(A1)
6767
B1d = randdual.(B1)
68-
@test reinterpret(Float64, Octavian.matmul(A1d, B1d)) reinterpret(Float64, A1d * B1d)
68+
@test reinterpret(Float64, Octavian.matmul(A1d, B1d, 1.3)) reinterpret(Float64, (A1d * B1d) .* 1.3)
6969
@test reinterpret(Float64, Octavian.matmul(@view(A1d[begin:end-1,:]), B1d)) reinterpret(Float64, @view(A1d[begin:end-1,:]) * B1d)
7070
end
7171
end

0 commit comments

Comments
 (0)