1
1
2
- first_effective_cache (:: Type{T} ) where {T} = StaticInt {FIRST__CACHE_SIZE} () ÷ static_sizeof (T)
3
- second_effective_cache (:: Type{T} ) where {T} = StaticInt {SECOND_CACHE_SIZE} () ÷ static_sizeof (T)
4
2
5
3
"""
    block_sizes(::Type{T}, _α, _β, R₁, R₂)

Prepare the inputs for `block_sizes(W, α, β, L₁ₑ, L₂ₑ)` for element type `T` and
return its result.

`_α` and `_β` are multiplied by the vector width `W` reported by
`pick_vector_width_val(T)`; `first_cache_size(T)` and `second_cache_size(T)` are
multiplied by `R₁` and `R₂` respectively.
"""
function block_sizes(::Type{T}, _α, _β, R₁, R₂) where {T}
    W = pick_vector_width_val(T)
    α = _α * W
    β = _β * W
    # NOTE(review): presumably R₁/R₂ are the fractions of each cache level the
    # kernel is allowed to use — confirm against the definitions of the defaults.
    L₁ₑ = first_cache_size(T) * R₁
    L₂ₑ = second_cache_size(T) * R₂
    return block_sizes(W, α, β, L₁ₑ, L₂ₑ)
end
13
11
"""
    block_sizes(W, α, β, L₁ₑ, L₂ₑ)

Return the tuple `(Mc, Kc, Nc)` computed from the vector width `W`, the weights
`α` and `β`, and the two effective cache sizes `L₁ₑ` and `L₂ₑ`.

`Mc` is rounded down to a multiple of `mᵣ * W` and `Nc` is rounded down to a
multiple of `nᵣ`, where `(mᵣ, nᵣ)` come from `matmul_params()`. `Kc` is rounded to
the nearest integer.
"""
function block_sizes(W, α, β, L₁ₑ, L₂ₑ)
    mᵣ, nᵣ = matmul_params()
    MᵣW = mᵣ * W

    # Divide by the tile extent before flooring, then multiply back, so that the
    # result is a whole number of tiles.
    Mc = floortostaticint(√(L₁ₑ)*√(L₁ₑ*β + L₂ₑ*α)/√(L₂ₑ) / MᵣW) * MᵣW
    Kc = roundtostaticint(√(L₁ₑ)*√(L₂ₑ)/√(L₁ₑ*β + L₂ₑ*α))
    Nc = floortostaticint(√(L₂ₑ)*√(L₁ₑ*β + L₂ₑ*α)/√(L₁ₑ) / nᵣ) * nᵣ

    return Mc, Kc, Nc
end
22
21
"""
    block_sizes(::Type{T})

Return `block_sizes(T, _α, _β, R₁, R₂)` evaluated with the default parameters
`W₁Default()`, `W₂Default()`, `R₁Default()` and `R₂Default()`.
"""
function block_sizes(::Type{T}) where {T}
    return block_sizes(T, W₁Default(), W₂Default(), R₁Default(), R₂Default())
end
25
24
26
25
"""
@@ -159,11 +158,11 @@ Note that for synchronization on `B`, all threads must have the same values for
159
158
independently of `M`, this algorithm guarantees all threads are on the same page.
160
159
"""
161
160
@inline function solve_block_sizes (:: Type{T} , M, K, N, _α, _β, R₂, R₃, Wfactor) where {T}
162
- W = VectorizationBase . pick_vector_width_val (T)
161
+ W = pick_vector_width_val (T)
163
162
α = _α * W
164
163
β = _β * W
165
- L₁ₑ = first_effective_cache (T) * R₂
166
- L₂ₑ = second_effective_cache (T) * R₃
164
+ L₁ₑ = first_cache_size (T) * R₂
165
+ L₂ₑ = second_cache_size (T) * R₃
167
166
168
167
# Nc_init = round(Int, √(L₂ₑ)*√(α * L₂ₑ + β * L₁ₑ)/√(L₁ₑ))
169
168
Nc_init⁻¹ = √ (L₁ₑ) / (√ (L₂ₑ)*√ (α * L₂ₑ + β * L₁ₑ))
@@ -178,11 +177,11 @@ independently of `M`, this algorithm guarantees all threads are on the same page
178
177
end
179
178
# Takes Nc, calcs Mc and Kc
180
179
@inline function solve_McKc (:: Type{T} , M, K, Nc, _α, _β, R₂, R₃, Wfactor) where {T}
181
- W = VectorizationBase . pick_vector_width_val (T)
180
+ W = pick_vector_width_val (T)
182
181
α = _α * W
183
182
β = _β * W
184
- L₁ₑ = first_effective_cache (T) * R₂
185
- L₂ₑ = second_effective_cache (T) * R₃
183
+ L₁ₑ = first_cache_size (T) * R₂
184
+ L₂ₑ = second_cache_size (T) * R₃
186
185
187
186
Kc_init⁻¹ = Base. FastMath. max_fast (√ (α/ L₁ₑ), Nc* inv (L₂ₑ))
188
187
Kiter = cldapproxi (K, Kc_init⁻¹) # approximate `ceil`
@@ -201,27 +200,28 @@ end
201
200
"""
202
201
find_first_acceptable(M, W)
203
202
204
- Finds first combination of `Miter` and `Niter` that doesn't make `M` too small while producing `Miter * Niter = NUM_CORES `.
203
+ Finds first combination of `Miter` and `Niter` that doesn't make `M` too small while producing `Miter * Niter = num_cores() `.
205
204
This would be awkward if there are computers with prime numbers of cores. I should probably consider that possibility at some point.
206
205
"""
207
206
@inline function find_first_acceptable(M, W)
    # Only the row-tile parameter is needed here; the column one is discarded.
    mᵣ, _ = matmul_params()
    # Rows covered by one micro-kernel tile, in elements. The version of this
    # function that used the `mᵣ` constant multiplied by `W`, and `divide_blocks`
    # uses `W * mᵣ` in the same role; without the factor the left-hand side of the
    # test below would not be in the same units as `M + (W + W)`.
    MᵣW = mᵣ * W
    factors = calc_factors()
    for (miter, niter) ∈ factors
        # Accept the first factorization whose `miter` row blocks still leave each
        # block at least `(MᵣW_mul_factor() - 1)` tiles tall, allowing a slack of
        # two vector widths on `M`.
        if miter * ((MᵣW_mul_factor() - One()) * MᵣW) ≤ M + (W + W)
            return miter, niter
        end
    end
    # No candidate satisfied the bound: fall back to the last entry of `factors`.
    return last(factors)
end
216
216
"""
217
217
divide_blocks(M, Ntotal, _nspawn, W)
218
218
219
219
Splits both `M` and `N` into blocks when trying to spawn a large number of threads relative to the size of the matrices.
220
220
"""
221
221
@inline function divide_blocks (M, Ntotal, _nspawn, W)
222
- _nspawn == NUM_CORES && return find_first_acceptable (M, W)
223
-
224
- Miter = clamp (div_fast (M, W* StaticInt {mᵣ} () * MᵣW_mul_factor), 1 , _nspawn)
222
+ _nspawn == num_cores () && return find_first_acceptable (M, W)
223
+ mᵣ, nᵣ = matmul_params ()
224
+ Miter = clamp (div_fast (M, W* mᵣ * MᵣW_mul_factor () ), 1 , _nspawn)
225
225
nspawn = div_fast (_nspawn, Miter)
226
226
if (nspawn ≤ 1 ) & (Miter < _nspawn)
227
227
# rebalance Miter
0 commit comments