@@ -591,7 +591,7 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
591591
592592 BLASLONG nthreads = args -> nthreads ;
593593
594- BLASLONG width , i , j , k , js ;
594+ BLASLONG width , width_n , i , j , k , js ;
595595 BLASLONG m , n , n_from , n_to ;
596596 int mode ;
597597#if defined(DYNAMIC_ARCH )
@@ -740,18 +740,25 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
740740 /* Partition (a step of) n into nthreads regions */
741741 range_N [0 ] = js ;
742742 num_parts = 0 ;
743- while (n > 0 ){
744- width = blas_quickdivide (n + nthreads - num_parts - 1 , nthreads - num_parts );
745- if (width < switch_ratio ) {
746- width = switch_ratio ;
743+ for (j = 0 ; j < nthreads_n ; j ++ ){
744+ width_n = blas_quickdivide (n + nthreads_n - j - 1 , nthreads_n - j );
745+ n -= width_n ;
746+ for (i = 0 ; i < nthreads_m ; i ++ ){
747+ width = blas_quickdivide (width_n + nthreads_m - i - 1 , nthreads_m - i );
748+ if (width < switch_ratio ) {
749+ width = switch_ratio ;
750+ }
751+ width = round_up (width_n , width , GEMM_PREFERED_SIZE );
752+
753+ width_n -= width ;
754+ if (width_n < 0 ) {
755+ width = width + width_n ;
756+ width_n = 0 ;
757+ }
758+ range_N [num_parts + 1 ] = range_N [num_parts ] + width ;
759+
760+ num_parts ++ ;
747761 }
748- width = round_up (n , width , GEMM_PREFERED_SIZE );
749-
750- n -= width ;
751- if (n < 0 ) width = width + n ;
752- range_N [num_parts + 1 ] = range_N [num_parts ] + width ;
753-
754- num_parts ++ ;
755762 }
756763 for (j = num_parts ; j < MAX_CPU_NUMBER ; j ++ ) {
757764 range_N [j + 1 ] = range_N [num_parts ];
0 commit comments