|
1 | 1 |
|
2 |
| - |
"""
    LoopMulFunc{P,TC,TA,TB,Α,Β,Md,Kd,Nd}

Field-less callable type; everything it needs is encoded in its type parameters.
`TC`, `TA`, `TB`, `Α`, `Β`, `Md`, `Kd` and `Nd` are the types handed to `load`, in
that order, when an instance is called. `P` is forwarded as `Val{P}()` to
`_call_loopmul!`.
"""
struct LoopMulFunc{P,TC,TA,TB,Α,Β,Md,Kd,Nd} <: Function end

# Call method: read `C, A, B, α, β, M, K, N` back out of the buffer at `p` and run
# `_call_loopmul!` on them. The first read uses offset `2*sizeof(UInt)`; every `load`
# returns the offset to use for the next read. Always returns `nothing`.
# NOTE(review): the read order here mirrors the `store!` order in `setup_matmul!`;
# the two must be kept in sync.
function (::LoopMulFunc{P,TC,TA,TB,Α,Β,Md,Kd,Nd})(p::Ptr{UInt}) where {P,TC,TA,TB,Α,Β,Md,Kd,Nd}
    offset, C = load(p, TC, 2*sizeof(UInt))
    offset, A = load(p, TA, offset)
    offset, B = load(p, TB, offset)
    offset, α = load(p, Α, offset)
    offset, β = load(p, Β, offset)
    offset, M = load(p, Md, offset)
    offset, K = load(p, Kd, offset)
    offset, N = load(p, Nd, offset)
    _call_loopmul!(C, A, B, α, β, M, K, N, Val{P}())
    return nothing
end
|
# `Val{false}` method: forward all arguments unchanged to `loopmul!`.
@inline _call_loopmul!(C, A, B, α, β, M, K, N, ::Val{false}) = loopmul!(C, A, B, α, β, M, K, N)
|
# `Val{true}` method, only for `C::StridedPointer{T}`.
# Chooses between two kernels based on `M*K`:
#   * below `first_cache_size(Val(T)) * R₂Default()` -> `packaloopmul!`
#   * otherwise -> `matmul_st_only_pack_A!` with the default `W₁/W₂/R₁/R₂` values.
# Returns `nothing` in both cases.
@inline function _call_loopmul!(C::StridedPointer{T}, A, B, α, β, M, K, N, ::Val{true}) where {T}
    if M*K < first_cache_size(Val(T)) * R₂Default()
        packaloopmul!(C, A, B, α, β, M, K, N)
    else
        matmul_st_only_pack_A!(C, A, B, α, β, M, K, N, W₁Default(), W₂Default(), R₁Default(), R₂Default())
    end
    return nothing
end
|
# Non-underscored entry point: rebuilds `Val{P}()` and dispatches to `_call_loopmul!`.
call_loopmul!(C, A, B, α, β, M, K, N, ::Val{P}) where {P} = _call_loopmul!(C, A, B, α, β, M, K, N, Val{P}())
|
27 | 26 |
|
"""
    SyncMulFunc{TC,TA,TB,Α,Β,Md,Kd,Nd,BCP,ID,TT,W₁,W₂,R₁,R₂}

Field-less callable type; everything it needs is encoded in its type parameters.
`TC` … `Nd`, `BCP`, `ID` and `TT` are types handed to `load` when an instance is
called. `W₁`, `W₂`, `R₁` and `R₂` are turned back into `StaticFloat64{…}()` values
and passed to `sync_mul!`.
"""
struct SyncMulFunc{TC,TA,TB,Α,Β,Md,Kd,Nd,BCP,ID,TT,W₁,W₂,R₁,R₂} <: Function end

# Call method: read the matmul arguments plus `atomicp`, `bcachep`, `id` and
# `total_ids` out of the buffer at `p`, then run `sync_mul!`. The first read uses
# offset `2*sizeof(UInt)`; every `load` returns the offset for the next read.
# `atomicp` is always read as a `Ptr{UInt32}`. Always returns `nothing`.
# NOTE(review): the read order here mirrors the `store!` order in `setup_syncmul!`;
# the two must be kept in sync.
function (::SyncMulFunc{TC,TA,TB,Α,Β,Md,Kd,Nd,BCP,ID,TT,W₁,W₂,R₁,R₂})(p::Ptr{UInt}) where {TC,TA,TB,Α,Β,Md,Kd,Nd,BCP,ID,TT,W₁,W₂,R₁,R₂}
    offset, C = load(p, TC, 2*sizeof(UInt))
    offset, A = load(p, TA, offset)
    offset, B = load(p, TB, offset)
    offset, α = load(p, Α, offset)
    offset, β = load(p, Β, offset)
    offset, M = load(p, Md, offset)
    offset, K = load(p, Kd, offset)
    offset, N = load(p, Nd, offset)
    offset, atomicp = load(p, Ptr{UInt32}, offset)
    offset, bcachep = load(p, BCP, offset)
    offset, id = load(p, ID, offset)
    offset, total_ids = load(p, TT, offset)
    sync_mul!(C, A, B, α, β, M, K, N, atomicp, bcachep, id, total_ids, StaticFloat64{W₁}(), StaticFloat64{W₂}(), StaticFloat64{R₁}(), StaticFloat64{R₂}())
    return nothing
end
|
45 | 44 |
|
# Build a C-callable pointer for a callable of type `T` with C signature
# `Cvoid (Ptr{UInt})`.
#
# This is a generated function, so the body below runs with `T` bound to the
# argument's *type*:
#   * `T()` constructs an instance, so `T` must have a zero-argument constructor
#     (as the field-less `LoopMulFunc` / `SyncMulFunc` types do);
#   * `precompile` is invoked on that instance for a `Ptr{UInt}` argument while the
#     code is being generated;
#   * the instance is spliced into the returned expression, so `@cfunction` receives
#     the instance itself rather than a variable name.
# The returned expression is marked `:inline`.
# NOTE(review): calling `precompile` inside a generator is a deliberate side effect;
# keep the statement order as is.
@generated function cfuncpointer(::T) where {T}
    precompile(T(), (Ptr{UInt},))
    quote
        $(Expr(:meta,:inline))
        @cfunction($(T()), Cvoid, (Ptr{UInt},))
    end
end
|
53 | 52 |
|
# Write a `LoopMulFunc` task into the buffer at `p`.
#
# First stores the C function pointer for `LoopMulFunc{P,TC,TA,TB,Α,Β,Md,Kd,Nd}`
# (specialised on the argument types and `P`) at offset `sizeof(UInt)`, then stores
# `C, A, B, α, β, M, K, N` one after another; each `store!` returns the offset for
# the next write. Returns `nothing`.
# NOTE(review): the write order mirrors the `load` order in the `LoopMulFunc` call
# method, which begins reading at `2*sizeof(UInt)` — presumably just past the
# function pointer; confirm against `store!`.
@inline function setup_matmul!(p::Ptr{UInt}, C::TC, A::TA, B::TB, α::Α, β::Β, M::Md, K::Kd, N::Nd, ::Val{P}) where {P,TC,TA,TB,Α,Β,Md,Kd,Nd}
    offset = store!(p, cfuncpointer(LoopMulFunc{P,TC,TA,TB,Α,Β,Md,Kd,Nd}()), sizeof(UInt))
    offset = store!(p, C, offset)
    offset = store!(p, A, offset)
    offset = store!(p, B, offset)
    offset = store!(p, α, offset)
    offset = store!(p, β, offset)
    offset = store!(p, M, offset)
    offset = store!(p, K, offset)
    offset = store!(p, N, offset)
    return nothing
end
|
66 | 65 |
|
# Write a `SyncMulFunc` task into the buffer at `p`.
#
# First stores the C function pointer for the `SyncMulFunc` specialised on the
# argument types and on `W₁, W₂, R₁, R₂` at offset `sizeof(UInt)`, then stores
# `C, A, B, α, β, M, K, N, ap, bcp, id, tt` one after another; each `store!`
# returns the offset for the next write. The four `StaticFloat64` arguments are
# not stored: they travel only in the function type. Returns `nothing`.
# NOTE(review): the write order mirrors the `load` order in the `SyncMulFunc` call
# method; the two must be kept in sync.
@inline function setup_syncmul!(
    p::Ptr{UInt}, C::TC, A::TA, B::TB, α::Α, β::Β, M::Md, K::Kd, N::Nd,
    ap::Ptr{UInt32},bcp::BCP,id::ID,tt::TT,::StaticFloat64{W₁},::StaticFloat64{W₂},::StaticFloat64{R₁},::StaticFloat64{R₂}
) where {TC,TA,TB,Α,Β,Md,Kd,Nd,BCP,ID,TT,W₁,W₂,R₁,R₂}
    offset = store!(p, cfuncpointer(SyncMulFunc{TC,TA,TB,Α,Β,Md,Kd,Nd,BCP,ID,TT,W₁,W₂,R₁,R₂}()), sizeof(UInt))
    offset = store!(p, C, offset)
    offset = store!(p, A, offset)
    offset = store!(p, B, offset)
    offset = store!(p, α, offset)
    offset = store!(p, β, offset)
    offset = store!(p, M, offset)
    offset = store!(p, K, offset)
    offset = store!(p, N, offset)
    offset = store!(p, ap, offset)
    offset = store!(p, bcp, offset)
    offset = store!(p, id, offset)
    offset = store!(p, tt, offset)
    return nothing
end
|
86 | 85 |
|
87 |
# Launch a plain (non-synchronised) matmul task on thread `tid`.
# `setup_matmul!` is passed to `launch` as the setup function, followed by `tid`
# and the arguments `setup_matmul!` expects after the buffer pointer.
# Returns whatever `launch` returns.
@inline function launch_thread_mul!(C, A, B, α, β, M, K, N, tid::UInt32, ::Val{P}) where {P}
    launch(setup_matmul!, tid, C, A, B, α, β, M, K, N, Val{P}())
end
|
92 |
# Launch a synchronised matmul task on thread `tid`.
# `tid` selects the thread handed to `launch`; `id` and `tt` are forwarded to
# `setup_syncmul!` together with the matmul arguments, `ap` and `bcp`.
# The do-block is the setup function given to `launch`: it receives the buffer
# pointer `p` plus the forwarded arguments and rebuilds the `StaticFloat64` values
# from the type parameters `W₁, W₂, R₁, R₂`.
# Returns whatever `launch` returns.
@inline function launch_thread_mul!(
    C, A, B, α, β, M, K, N, ap, bcp, tid, id, tt, ::StaticFloat64{W₁},::StaticFloat64{W₂},::StaticFloat64{R₁},::StaticFloat64{R₂}
) where {W₁,W₂,R₁,R₂}
    launch(tid, C, A, B, α, β, M, K, N, ap, bcp, id, tt) do p, C, A, B, α, β, M, K, N, ap, bcp, id, tt
        Base.@_inline_meta  # ask the compiler to inline this closure
        setup_syncmul!(
            p, C, A, B, α, β, M, K, N, ap, bcp, id, tt,
            StaticFloat64{W₁}(),StaticFloat64{W₂}(),StaticFloat64{R₁}(),StaticFloat64{R₂}()
        )
    end
end
|
102 | 100 |
|
103 | 101 |
|
0 commit comments