|
26 | 26 | # we can avoid the reshape and call the standard method
|
27 | 27 | A = reinterpret(T, _A)
|
28 | 28 | C = reinterpret(T, _C)
|
29 |
| - _matmul!(C, A, B, α, β, nthread, MKN) |
| 29 | + _matmul!(C, A, B, α, β, nthread, nothing) |
30 | 30 | else
|
31 | 31 | # we cannot use the standard method directly
|
32 | 32 | A = real_rep(_A)
|
|
43 | 43 |
|
44 | 44 | _C
|
45 | 45 | end
|
| 46 | + |
| 47 | +_view1(B::AbstractMatrix) = @view(B[1,:]) |
| 48 | +_view1(B::AbstractArray{<:Any,3}) = @view(B[1,:,:]) |
| 49 | +@inline function _matmul!(_C::AbstractVecOrMat{DualT}, _A::AbstractMatrix{DualT}, _B::AbstractVecOrMat{DualT}, |
| 50 | + α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing) where {TAG, T, P, DualT <: ForwardDiff.Dual{TAG, T, P}} |
| 51 | + A = real_rep(_A) |
| 52 | + C = real_rep(_C) |
| 53 | + B = real_rep(_B) |
| 54 | + if all((ArrayInterface.is_dense(_C), ArrayInterface.is_column_major(_C), |
| 55 | + ArrayInterface.is_dense(_A), ArrayInterface.is_column_major(_A))) |
| 56 | + # we can avoid the reshape and call the standard method |
| 57 | + Ar = reinterpret(T, _A) |
| 58 | + Cr = reinterpret(T, _C) |
| 59 | + _matmul!(Cr, Ar, _view1(B), α, β, nthread, nothing) |
| 60 | + else |
| 61 | + # we cannot use the standard method directly |
| 62 | + @tturbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), 2), l in indices((C, A), 1) |
| 63 | + Cₗₘₙ = zero(eltype(C)) |
| 64 | + for k ∈ indices((A, B), (3, 2)) |
| 65 | + Cₗₘₙ += A[l, m, k] * B[1, k, n] |
| 66 | + end |
| 67 | + C[l, m, n] = α * Cₗₘₙ + β * C[l, m, n] |
| 68 | + end |
| 69 | + end |
| 70 | + Pstatic = static(P) |
| 71 | + @tturbo for n ∈ indices((B,C),3), m ∈ indices((A,C),2), p ∈ 1:Pstatic, k ∈ indices((A,B),(3,2)) |
| 72 | + C[p+1,m,n] += A[1,m,k] * B[p+1,k,n] |
| 73 | + end |
| 74 | + _C |
| 75 | +end |
0 commit comments