@@ -11,19 +11,17 @@ function block_sizes(::Val{T}, _α, _β, R₁, R₂) where {T}
11
11
block_sizes (Val (T), W, α, β, L₁ₑ, L₂ₑ)
12
12
end
13
13
function block_sizes (:: Val{T} , W, α, β, L₁ₑ, L₂ₑ) where {T}
14
- mᵣnᵣ = matmul_params (Val (T))
15
- mᵣ = getfield (mᵣnᵣ, 1 )
16
- nᵣ = getfield (mᵣnᵣ, 2 )
14
+ mᵣ, nᵣ = matmul_params (Val (T))
17
15
MᵣW = mᵣ * W
18
-
19
- Mc = floortostaticint (√ (L₁ₑ)* √ (L₁ₑ* β + L₂ₑ* α) / √ (L₂ₑ) / StaticFloat64 (MᵣW)) * MᵣW
20
- Kc = roundtostaticint (√ (L₁ₑ)* √ (L₂ₑ)/ √ (L₁ₑ* β + L₂ₑ* α))
21
- Nc = floortostaticint (√ (L₂ₑ)* √ (L₁ₑ* β + L₂ₑ* α) / √ (L₁ₑ) / StaticFloat64 (nᵣ)) * nᵣ
22
-
16
+
17
+ Mc = floortostaticint (√ (L₁ₑ) * √ (L₁ₑ * β + L₂ₑ * α) / √ (L₂ₑ) / StaticFloat64 (MᵣW)) * MᵣW
18
+ Kc = roundtostaticint (√ (L₁ₑ) * √ (L₂ₑ) / √ (L₁ₑ * β + L₂ₑ * α))
19
+ Nc = floortostaticint (√ (L₂ₑ) * √ (L₁ₑ * β + L₂ₑ * α) / √ (L₁ₑ) / StaticFloat64 (nᵣ)) * nᵣ
20
+
23
21
Mc, Kc, Nc
24
22
end
25
23
function block_sizes (:: Val{T} ) where {T}
26
- block_sizes (Val (T), W₁Default (), W₂Default (), R₁Default (), R₂Default ())
24
+ block_sizes (Val (T), W₁Default (), W₂Default (), R₁Default (), R₂Default ())
27
25
end
28
26
29
27
"""
@@ -48,12 +46,12 @@ This is meant to specify roughly the requested amount of blocks, and return rela
48
46
This method is used fairly generally.
49
47
"""
50
48
@inline function split_m (M, _Mblocks, W)
51
- Miters = cld_fast (M, W)
52
- Mblocks = min (_Mblocks, Miters)
53
- Miter_per_block, Mrem = divrem_fast (Miters, Mblocks)
54
- Mbsize = Miter_per_block * W
55
- Mremfinal = M - Mbsize* (Mblocks- 1 ) - Mrem * W
56
- Mbsize, Mrem, Mremfinal, Mblocks
49
+ Miters = cld_fast (M, W)
50
+ Mblocks = min (_Mblocks, Miters)
51
+ Miter_per_block, Mrem = divrem_fast (Miters, Mblocks)
52
+ Mbsize = Miter_per_block * W
53
+ Mremfinal = M - Mbsize * (Mblocks - 1 ) - Mrem * W
54
+ Mbsize, Mrem, Mremfinal, Mblocks
57
55
end
58
56
59
57
"""
@@ -162,33 +160,36 @@ Note that for synchronization on `B`, all threads must have the same values for
162
160
independently of `M`, this algorithm guarantees all threads are on the same page.
163
161
"""
164
162
@inline function solve_block_sizes (:: Val{T} , M, K, N, _α, _β, R₂, R₃, Wfactor) where {T}
165
- W = pick_vector_width (T)
166
- α = _α * W
167
- β = _β * W
168
- L₁ₑ = first_cache_size (Val (T)) * R₂
169
- L₂ₑ = second_cache_size (Val (T)) * R₃
163
+ W = pick_vector_width (T)
164
+ α = _α * W
165
+ β = _β * W
166
+ L₁ₑ = first_cache_size (Val (T)) * R₂
167
+ L₂ₑ = second_cache_size (Val (T)) * R₃
170
168
171
- # Nc_init = round(Int, √(L₂ₑ)*√(α * L₂ₑ + β * L₁ₑ)/√(L₁ₑ))
172
- Nc_init⁻¹ = √ (L₁ₑ) / (√ (L₂ₑ)*√ (α * L₂ₑ + β * L₁ₑ))
173
-
174
- Niter = cldapproxi (N, Nc_init⁻¹) # approximate `ceil`
175
- Nblock, Nrem = divrem_fast (N, Niter)
176
- Nblock_Nrem = Nblock + One ()# (Nrem > 0)
169
+ # Nc_init = round(Int, √(L₂ₑ)*√(α * L₂ₑ + β * L₁ₑ)/√(L₁ₑ))
170
+ Nc_init⁻¹ = √ (L₁ₑ) / (√ (L₂ₑ) * √ (α * L₂ₑ + β * L₁ₑ))
177
171
178
- ((Mblock, Mblock_Mrem, Mremfinal, Mrem, Miter), (Kblock, Kblock_Krem, Krem, Kiter)) = solve_McKc (Val (T), M, K, Nblock_Nrem, _α, _β, R₂, R₃, Wfactor)
179
-
180
- (Mblock, Mblock_Mrem, Mremfinal, Mrem, Miter), (Kblock, Kblock_Krem, Krem, Kiter), promote (Nblock, Nblock_Nrem, Nrem, Niter)
172
+ Niter = cldapproxi (N, Nc_init⁻¹) # approximate `ceil`
173
+ Nblock, Nrem = divrem_fast (N, Niter)
174
+ Nblock_Nrem = Nblock + One ()# (Nrem > 0)
175
+
176
+ ((Mblock, Mblock_Mrem, Mremfinal, Mrem, Miter), (Kblock, Kblock_Krem, Krem, Kiter)) =
177
+ solve_McKc (Val (T), M, K, Nblock_Nrem, _α, _β, R₂, R₃, Wfactor)
178
+
179
+ (Mblock, Mblock_Mrem, Mremfinal, Mrem, Miter),
180
+ (Kblock, Kblock_Krem, Krem, Kiter),
181
+ promote (Nblock, Nblock_Nrem, Nrem, Niter)
181
182
end
182
183
# Takes Nc, calcs Mc and Kc
183
184
@inline function solve_McKc (:: Val{T} , M, K, Nc, _α, _β, R₂, R₃, Wfactor) where {T}
184
185
W = pick_vector_width (T)
185
186
Wfloat = StaticFloat64 (W)
186
187
α = _α * Wfloat
187
- β = _β * Wfloat
188
- L₁ₑ = first_cache_size (Val (T)) * R₂
188
+ # β = _β * Wfloat
189
+ L₁ₑ = first_cache_size (Val (T)) * R₂
189
190
L₂ₑ = second_cache_size (Val (T)) * R₃
190
191
191
- Kc_init⁻¹ = Base. FastMath. max_fast (√ (α/ L₁ₑ), Nc* inv (L₂ₑ))
192
+ Kc_init⁻¹ = Base. FastMath. max_fast (√ (α / L₁ₑ), Nc * inv (L₂ₑ))
192
193
Kiter = cldapproxi (K, Kc_init⁻¹) # approximate `ceil`
193
194
Kblock, Krem = divrem_fast (K, Kiter)
194
195
Kblock_Krem = Kblock + One ()
202
203
Mblocks, Mblocks_rem = divrem_fast (M, Mᵣ)
203
204
Miter, Mrem = divrem_fast (Mblocks, Mc_init_base)
204
205
if Miter == 0
205
- return (0 , 0 , Int (M):: Int , 0 , 1 ), Kblock_summary
206
+ return (0 , 0 , Int (M):: Int , 0 , 1 ), Kblock_summary
206
207
elseif Miter > Mrem
207
208
Mblock_Mrem = Mbsize + Mᵣ
208
209
Mremfinal = Mbsize + Mblocks_rem
221
222
end
222
223
end
223
224
224
- @inline cldapproxi (n, d⁻¹) = Base. fptosi (Int, Base. FastMath. add_fast (Base. FastMath. mul_fast (n, d⁻¹), 0.9999999999999432 )) # approximate `ceil`
225
+ @inline cldapproxi (n, d⁻¹) = Base. fptosi (
226
+ Int,
227
+ Base. FastMath. add_fast (Base. FastMath. mul_fast (n, d⁻¹), 0.9999999999999432 ),
228
+ ) # approximate `ceil`
225
229
# @inline divapproxi(n, d⁻¹) = Base.fptosi(Int, Base.FastMath.mul_fast(n, d⁻¹)) # approximate `div`
226
230
227
231
"""
@@ -231,14 +235,14 @@ Finds first combination of `Miter` and `Niter` that doesn't make `M` too small w
231
235
This would be awkard if there are computers with prime numbers of cores. I should probably consider that possibility at some point.
232
236
"""
233
237
@inline function find_first_acceptable (:: Val{T} , M, W) where {T}
234
- Mᵣ, Nᵣ = matmul_params (Val (T))
235
- factors = calc_factors ()
236
- for (miter, niter) ∈ factors
237
- if miter * (StaticInt (2 ) * Mᵣ * W) ≤ M + (W + W)
238
- return miter, niter
239
- end
238
+ Mᵣ, _ = matmul_params (Val (T))
239
+ factors = calc_factors ()
240
+ for (miter, niter) ∈ factors
241
+ if miter * (StaticInt (2 ) * Mᵣ * W) ≤ M + (W + W)
242
+ return miter, niter
240
243
end
241
- last (factors)
244
+ end
245
+ last (factors)
242
246
end
243
247
"""
244
248
divide_blocks(M, Ntotal, _nspawn, W)
@@ -247,8 +251,8 @@ Splits both `M` and `N` into blocks when trying to spawn a large number of threa
247
251
"""
248
252
@inline function divide_blocks (:: Val{T} , M, Ntotal, _nspawn, W) where {T}
249
253
_nspawn == num_cores () && return find_first_acceptable (Val (T), M, W)
250
- mᵣ, nᵣ = matmul_params (Val (T))
251
- Miter = clamp (div_fast (M, W* mᵣ * MᵣW_mul_factor ()), 1 , _nspawn)
254
+ mᵣ, _ = matmul_params (Val (T))
255
+ Miter = clamp (div_fast (M, W * mᵣ * MᵣW_mul_factor ()), 1 , _nspawn)
252
256
nspawn = div_fast (_nspawn, Miter)
253
257
if (nspawn ≤ 1 ) & (Miter < _nspawn)
254
258
# rebalance Miter
0 commit comments