1- using System ;
1+ using BrightData . LinearAlgebra . ReadOnly ;
2+ using BrightData . LinearAlgebra . Segments ;
3+ using CommunityToolkit . HighPerformance . Buffers ;
4+ using System ;
25using System . Collections . Generic ;
36using System . Diagnostics ;
47using System . Linq ;
58using System . Numerics ;
69using System . Runtime . CompilerServices ;
710using System . Runtime . InteropServices ;
11+ using System . Runtime . Intrinsics ;
12+ using System . Runtime . Intrinsics . X86 ;
813using System . Threading . Tasks ;
9- using BrightData . LinearAlgebra . ReadOnly ;
10- using BrightData . LinearAlgebra . Segments ;
11- using CommunityToolkit . HighPerformance . Buffers ;
1214
1315namespace BrightData . LinearAlgebra
1416{
@@ -318,7 +320,10 @@ static unsafe IMatrix<T> MultiplyWithThisTransposed(LinearAlgebraProvider<T> lap
318320 fixed ( T * matrixPtr = matrixSpan )
319321 fixed ( T * otherPtr = otherSpan )
320322 fixed ( T * retPtr = retSpan ) {
321- MatrixMultiplyTiled3 ( matrixPtr , otherPtr , lda , rowCount , columnCount , retPtr ) ;
323+ if ( typeof ( T ) == typeof ( float ) && Avx2 . IsSupported && Fma . IsSupported )
324+ MatrixMultiplyFloat ( ( float * ) matrixPtr , ( float * ) otherPtr , lda , ( int ) rowCount , ( int ) columnCount , ( float * ) retPtr ) ;
325+ else
326+ MatrixMultiplyTiled3 ( matrixPtr , otherPtr , lda , rowCount , columnCount , retPtr ) ;
322327 }
323328 }
324329 finally {
@@ -496,7 +501,7 @@ void MultiplyBlock(uint rowStart, uint colStart, uint rowEnd, uint colEnd)
496501 for ( uint jj = j , jLen = Math . Min ( j + L1BlockSize , cols ) ; jj < jLen ; jj ++ ) {
497502 var yPtr = & b [ jj * size ] ;
498503 var vSum = Vector < T > . Zero ;
499- for ( var z = 0 ; z < numVectors ; z ++ )
504+ for ( var z = 0 ; z < numVectors ; z ++ )
500505 vSum += Vector . Load ( xPtr + z * vectorSize ) * Vector . Load ( yPtr + z * vectorSize ) ;
501506
502507 var sum = Vector . Dot ( vSum , Vector < T > . One ) ;
@@ -510,8 +515,7 @@ void MultiplyBlock(uint rowStart, uint colStart, uint rowEnd, uint colEnd)
510515 }
511516
512517 if ( rows * cols >= Consts . MinimumSizeForParallel ) {
513- Parallel . For ( 0 , ( int ) Math . Ceiling ( ( double ) rows / L2BlockSize ) , rowTile =>
514- {
518+ Parallel . For ( 0 , ( int ) Math . Ceiling ( ( double ) rows / L2BlockSize ) , rowTile => {
515519 var rowStart = ( uint ) rowTile * L2BlockSize ;
516520 var rowEnd = rowStart + L2BlockSize ;
517521 for ( var colTile = 0U ; colTile < cols ; colTile += L2BlockSize )
@@ -543,5 +547,100 @@ public override string ToString()
543547
/// <inheritdoc />
protected override IMatrix<T> Create(MemoryOwner<T> memory)
{
    // Wrap the supplied buffer in a tensor segment, then build a matrix with this instance's dimensions and provider.
    var segment = new ArrayPoolTensorSegment<T>(memory);
    return new MutableMatrix<T, LAP>(segment, RowCount, ColumnCount, Lap);
}
550+
/// <summary>
/// Single-precision kernel that writes, for every i in [0, M) and j in [0, N),
/// the dot product of row i of <paramref name="a"/> with row j of <paramref name="b"/>
/// into <c>ret[j * M + i]</c>. Both inputs are read as contiguous rows of <paramref name="K"/> floats.
/// </summary>
/// <remarks>
/// Uses AVX loads, FMA multiply-adds and SSE/SSE3 adds without checking for hardware support;
/// the caller is expected to have verified that these instruction sets are available.
/// </remarks>
/// <param name="a">Pointer to M rows of K floats</param>
/// <param name="b">Pointer to N rows of K floats</param>
/// <param name="K">Length of each row of <paramref name="a"/> and <paramref name="b"/></param>
/// <param name="M">Number of rows in <paramref name="a"/>; also the stride between consecutive j in <paramref name="ret"/></param>
/// <param name="N">Number of rows in <paramref name="b"/></param>
/// <param name="ret">Pointer to N * M floats that receive the results</param>
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
static unsafe void MatrixMultiplyFloat(float* a, float* b, int K, int M, int N, float* ret)
{
    const int BlockSize = 64;

    // Nothing to write when either output dimension is empty
    if (M <= 0 || N <= 0)
        return;

    // Schedule over the full 2-D grid of output tiles so that the work is still spread
    // across cores when one of the dimensions fits inside a single block.
    long rowBlocks = ((long)M + BlockSize - 1) / BlockSize;
    long columnBlocks = ((long)N + BlockSize - 1) / BlockSize;

    Parallel.For(0L, rowBlocks * columnBlocks, tileIndex => {
        // Consecutive tile indices share the same block of b rows and walk across the rows of a
        var jjStart = (int)(tileIndex / rowBlocks) * BlockSize;
        var iiStart = (int)(tileIndex % rowBlocks) * BlockSize;
        var jjEnd = Math.Min(jjStart + BlockSize, N);
        var iiEnd = Math.Min(iiStart + BlockSize, M);
        ProcessBlock(a, b, ret, K, M, iiStart, iiEnd, jjStart, jjEnd);
    });

    // Computes one output tile: rows [iStart, iEnd) of a against rows [jStart, jEnd) of b
    [MethodImpl(MethodImplOptions.AggressiveOptimization)]
    static unsafe void ProcessBlock(float* a, float* b, float* ret, int K, int strideRet, int iStart, int iEnd, int jStart, int jEnd)
    {
        for (var jj = jStart; jj < jEnd; jj++) {
            var ptrB = b + (long)jj * K;
            var ii = iStart;

            // Four rows of a at a time so that each pair of loads from b is reused four times
            for (; ii < iEnd - 3; ii += 4) {
                var ptrA0 = a + (long)ii * K;
                var ptrA1 = a + (long)(ii + 1) * K;
                var ptrA2 = a + (long)(ii + 2) * K;
                var ptrA3 = a + (long)(ii + 3) * K;

                var sum0 = Vector256<float>.Zero;
                var sum1 = Vector256<float>.Zero;
                var sum2 = Vector256<float>.Zero;
                var sum3 = Vector256<float>.Zero;

                // Main loop consumes 16 floats (two 256-bit vectors) per iteration
                var k = 0;
                var kLimit = K - 15;
                for (; k < kLimit; k += 16) {
                    var bVec1 = Avx.LoadVector256(ptrB + k);
                    var bVec2 = Avx.LoadVector256(ptrB + k + 8);

                    sum0 = Fma.MultiplyAdd(Avx.LoadVector256(ptrA0 + k), bVec1, sum0);
                    sum0 = Fma.MultiplyAdd(Avx.LoadVector256(ptrA0 + k + 8), bVec2, sum0);
                    sum1 = Fma.MultiplyAdd(Avx.LoadVector256(ptrA1 + k), bVec1, sum1);
                    sum1 = Fma.MultiplyAdd(Avx.LoadVector256(ptrA1 + k + 8), bVec2, sum1);
                    sum2 = Fma.MultiplyAdd(Avx.LoadVector256(ptrA2 + k), bVec1, sum2);
                    sum2 = Fma.MultiplyAdd(Avx.LoadVector256(ptrA2 + k + 8), bVec2, sum2);
                    sum3 = Fma.MultiplyAdd(Avx.LoadVector256(ptrA3 + k), bVec1, sum3);
                    sum3 = Fma.MultiplyAdd(Avx.LoadVector256(ptrA3 + k + 8), bVec2, sum3);
                }

                var s0 = HorizontalAdd(sum0);
                var s1 = HorizontalAdd(sum1);
                var s2 = HorizontalAdd(sum2);
                var s3 = HorizontalAdd(sum3);

                // Scalar tail for the (up to 15) elements the vector loop did not cover
                for (; k < K; k++) {
                    var bVal = ptrB[k];
                    s0 += ptrA0[k] * bVal;
                    s1 += ptrA1[k] * bVal;
                    s2 += ptrA2[k] * bVal;
                    s3 += ptrA3[k] * bVal;
                }

                var baseIdx = (long)jj * strideRet + ii;
                ret[baseIdx] = s0;
                ret[baseIdx + 1] = s1;
                ret[baseIdx + 2] = s2;
                ret[baseIdx + 3] = s3;
            }

            // Remaining rows of a (fewer than four) are handled one at a time, 8 floats per step
            for (; ii < iEnd; ii++) {
                var ptrA = a + (long)ii * K;
                var vSum = Vector256<float>.Zero;
                var k = 0;
                for (; k <= K - 8; k += 8)
                    vSum = Fma.MultiplyAdd(Avx.LoadVector256(ptrA + k), Avx.LoadVector256(ptrB + k), vSum);

                var sum = HorizontalAdd(vSum);
                for (; k < K; k++)
                    sum += ptrA[k] * ptrB[k];
                ret[(long)jj * strideRet + ii] = sum;
            }
        }
    }

    // Adds the eight lanes of a 256-bit vector into a single float
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    static float HorizontalAdd(Vector256<float> v)
    {
        var vLow = v.GetLower();
        var vHigh = Avx.ExtractVector128(v, 1);
        var v128 = Sse.Add(vLow, vHigh);
        v128 = Sse3.HorizontalAdd(v128, v128);
        v128 = Sse3.HorizontalAdd(v128, v128);
        return v128.ToScalar();
    }
}
546645 }
547646}
0 commit comments