Skip to content

Commit adb517f

Browse files
author
Jack Dermody
committed
added improved float matrix multiplication
1 parent e0ebd56 commit adb517f

File tree

1 file changed

+107
-8
lines changed

1 file changed

+107
-8
lines changed

BrightData/LinearAlgebra/MutableMatrix.cs

Lines changed: 107 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
using System;
1+
using BrightData.LinearAlgebra.ReadOnly;
2+
using BrightData.LinearAlgebra.Segments;
3+
using CommunityToolkit.HighPerformance.Buffers;
4+
using System;
25
using System.Collections.Generic;
36
using System.Diagnostics;
47
using System.Linq;
58
using System.Numerics;
69
using System.Runtime.CompilerServices;
710
using System.Runtime.InteropServices;
11+
using System.Runtime.Intrinsics;
12+
using System.Runtime.Intrinsics.X86;
813
using System.Threading.Tasks;
9-
using BrightData.LinearAlgebra.ReadOnly;
10-
using BrightData.LinearAlgebra.Segments;
11-
using CommunityToolkit.HighPerformance.Buffers;
1214

1315
namespace BrightData.LinearAlgebra
1416
{
@@ -318,7 +320,10 @@ static unsafe IMatrix<T> MultiplyWithThisTransposed(LinearAlgebraProvider<T> lap
318320
fixed (T* matrixPtr = matrixSpan)
319321
fixed (T* otherPtr = otherSpan)
320322
fixed (T* retPtr = retSpan) {
321-
MatrixMultiplyTiled3(matrixPtr, otherPtr, lda, rowCount, columnCount, retPtr);
323+
if (typeof(T) == typeof(float) && Avx2.IsSupported && Fma.IsSupported)
324+
MatrixMultiplyFloat((float*)matrixPtr, (float*)otherPtr, lda, (int)rowCount, (int)columnCount, (float*)retPtr);
325+
else
326+
MatrixMultiplyTiled3(matrixPtr, otherPtr, lda, rowCount, columnCount, retPtr);
322327
}
323328
}
324329
finally {
@@ -496,7 +501,7 @@ void MultiplyBlock(uint rowStart, uint colStart, uint rowEnd, uint colEnd)
496501
for (uint jj = j, jLen = Math.Min(j + L1BlockSize, cols); jj < jLen; jj++) {
497502
var yPtr = &b[jj * size];
498503
var vSum = Vector<T>.Zero;
499-
for (var z = 0; z < numVectors; z++)
504+
for (var z = 0; z < numVectors; z++)
500505
vSum += Vector.Load(xPtr + z * vectorSize) * Vector.Load(yPtr + z * vectorSize);
501506

502507
var sum = Vector.Dot(vSum, Vector<T>.One);
@@ -510,8 +515,7 @@ void MultiplyBlock(uint rowStart, uint colStart, uint rowEnd, uint colEnd)
510515
}
511516

512517
if (rows * cols >= Consts.MinimumSizeForParallel) {
513-
Parallel.For(0, (int)Math.Ceiling((double)rows / L2BlockSize), rowTile =>
514-
{
518+
Parallel.For(0, (int)Math.Ceiling((double)rows / L2BlockSize), rowTile => {
515519
var rowStart = (uint)rowTile * L2BlockSize;
516520
var rowEnd = rowStart + L2BlockSize;
517521
for (var colTile = 0U; colTile < cols; colTile += L2BlockSize)
@@ -543,5 +547,100 @@ public override string ToString()
543547

544548
/// <inheritdoc />
545549
protected override IMatrix<T> Create(MemoryOwner<T> memory) => new MutableMatrix<T, LAP>(new ArrayPoolTensorSegment<T>(memory), RowCount, ColumnCount, Lap);
550+
551+
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
552+
static unsafe void MatrixMultiplyFloat(float* a, float* b, int K, int M, int N, float* ret)
553+
{
554+
const int BLOCK_SIZE = 64;
555+
556+
Parallel.For(0, (int)Math.Ceiling((double)N / BLOCK_SIZE), jBlockIndex =>
557+
{
558+
var jjStart = jBlockIndex * BLOCK_SIZE;
559+
var jjEnd = Math.Min(jjStart + BLOCK_SIZE, N);
560+
for (int iiStart = 0; iiStart < M; iiStart += BLOCK_SIZE) {
561+
int iiEnd = Math.Min(iiStart + BLOCK_SIZE, M);
562+
ProcessBlock(a, b, ret, K, M, iiStart, iiEnd, jjStart, jjEnd);
563+
}
564+
});
565+
566+
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
567+
static unsafe void ProcessBlock(float* a, float* b, float* ret, int K, int strideRet, int iStart, int iEnd, int jStart, int jEnd)
568+
{
569+
for (var jj = jStart; jj < jEnd; jj++) {
570+
var ptrB = b + (long)jj * K;
571+
var ii = iStart;
572+
573+
for (; ii < iEnd - 3; ii += 4) {
574+
var ptrA0 = a + (long)ii * K;
575+
var ptrA1 = a + (long)(ii + 1) * K;
576+
var ptrA2 = a + (long)(ii + 2) * K;
577+
var ptrA3 = a + (long)(ii + 3) * K;
578+
579+
var sum0 = Vector256<float>.Zero;
580+
var sum1 = Vector256<float>.Zero;
581+
var sum2 = Vector256<float>.Zero;
582+
var sum3 = Vector256<float>.Zero;
583+
584+
var k = 0;
585+
var kLimit = K - 15;
586+
for (; k < kLimit; k += 16) {
587+
var bVec1 = Avx.LoadVector256(ptrB + k);
588+
var bVec2 = Avx.LoadVector256(ptrB + k + 8);
589+
590+
sum0 = Fma.MultiplyAdd(Avx.LoadVector256(ptrA0 + k), bVec1, sum0);
591+
sum0 = Fma.MultiplyAdd(Avx.LoadVector256(ptrA0 + k + 8), bVec2, sum0);
592+
sum1 = Fma.MultiplyAdd(Avx.LoadVector256(ptrA1 + k), bVec1, sum1);
593+
sum1 = Fma.MultiplyAdd(Avx.LoadVector256(ptrA1 + k + 8), bVec2, sum1);
594+
sum2 = Fma.MultiplyAdd(Avx.LoadVector256(ptrA2 + k), bVec1, sum2);
595+
sum2 = Fma.MultiplyAdd(Avx.LoadVector256(ptrA2 + k + 8), bVec2, sum2);
596+
sum3 = Fma.MultiplyAdd(Avx.LoadVector256(ptrA3 + k), bVec1, sum3);
597+
sum3 = Fma.MultiplyAdd(Avx.LoadVector256(ptrA3 + k + 8), bVec2, sum3);
598+
}
599+
var s0 = HorizontalAdd(sum0);
600+
var s1 = HorizontalAdd(sum1);
601+
var s2 = HorizontalAdd(sum2);
602+
var s3 = HorizontalAdd(sum3);
603+
604+
for (; k < K; k++) {
605+
var bVal = ptrB[k];
606+
s0 += ptrA0[k] * bVal;
607+
s1 += ptrA1[k] * bVal;
608+
s2 += ptrA2[k] * bVal;
609+
s3 += ptrA3[k] * bVal;
610+
}
611+
612+
var baseIdx = (long)jj * strideRet + ii;
613+
ret[baseIdx] = s0;
614+
ret[baseIdx + 1] = s1;
615+
ret[baseIdx + 2] = s2;
616+
ret[baseIdx + 3] = s3;
617+
}
618+
619+
for (; ii < iEnd; ii++) {
620+
var ptrA = a + (long)ii * K;
621+
var vSum = Vector256<float>.Zero;
622+
var k = 0;
623+
for (; k <= K - 8; k += 8)
624+
vSum = Fma.MultiplyAdd(Avx.LoadVector256(ptrA + k), Avx.LoadVector256(ptrB + k), vSum);
625+
626+
var sum = HorizontalAdd(vSum);
627+
for (; k < K; k++)
628+
sum += ptrA[k] * ptrB[k];
629+
ret[(long)jj * strideRet + ii] = sum;
630+
}
631+
}
632+
}
633+
634+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
635+
static float HorizontalAdd(Vector256<float> v)
636+
{
637+
var vLow = v.GetLower();
638+
var vHigh = Avx.ExtractVector128(v, 1);
639+
var v128 = Sse.Add(vLow, vHigh);
640+
v128 = Sse3.HorizontalAdd(v128, v128);
641+
v128 = Sse3.HorizontalAdd(v128, v128);
642+
return v128.ToScalar();
643+
}
644+
}
546645
}
547646
}

0 commit comments

Comments
 (0)