Skip to content

Commit 3725d1a

Browse files
committed
Started testing gemm, need to work out some stuff.
1 parent 3165719 commit 3725d1a

File tree

13 files changed

+695
-154
lines changed

13 files changed

+695
-154
lines changed

scripts/uberenv/packages/lvarray/package.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,10 @@ class Lvarray(CMakePackage, CudaPackage):
6565
variant('addr2line', default=True,
6666
description='Build support for addr2line.')
6767

68-
<<<<<<< HEAD
6968
variant('tpl_build_type', default='none', description='TPL build type',
7069
values=('Debug', 'Release', 'RelWithDebInfo', 'MinSizeRel', 'none'))
7170

72-
73-
# conflicts('~lapack', when='+magma')
74-
=======
7571
conflicts('~lapack', when='+magma')
76-
>>>>>>> cde43f2 (Building and compiling with MAGMA. GPU not yet working, think it's something to do with the new workspaces.)
7772

7873
depends_on('[email protected]:', when='@0.2.0:', type='build')
7974

@@ -114,6 +109,7 @@ class Lvarray(CMakePackage, CudaPackage):
114109
depends_on('umpire build_type={}'.format(bt))
115110
depends_on('chai build_type={}'.format(bt), when='+chai')
116111
depends_on('caliper build_type={}'.format(bt), when='+caliper')
112+
depends_on('magma build_type={}'.format(bt), when='+magma')
117113

118114
phases = ['hostconfig', 'cmake', 'build', 'install']
119115

scripts/uberenv/spack_configs/toss_4_x86_64_ib/packages.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@ packages:
33
target: [default]
44
compiler: [gcc, clang, intel]
55
providers:
6-
blas: [intel-mkl]
7-
lapack: [intel-mkl]
6+
blas: [intel-oneapi-mkl]
7+
lapack: [intel-oneapi-mkl]
88

9-
intel-mkl:
9+
intel-oneapi-mkl:
1010
buildable: False
1111
externals:
12-
- spec: intel-mkl@2020.0.166 threads=openmp
13-
prefix: /usr/tce/packages/mkl/mkl-2020.0/
12+
- spec: intel-oneapi-mkl@2022.1.0
13+
prefix: /usr/tce/backend/installations/linux-rhel8-x86_64/intel-19.0.4/intel-oneapi-mkl-2022.1.0-sksz67twjxftvwchnagedk36gf7plkrp/
1414

1515
cmake:
1616
buildable: False

src/dense/BlasLapackInterface.cpp

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
#include "BlasLapackInterface.hpp"
2+
#include "backendHelpers.hpp"
3+
4+
extern "C"
5+
{
6+
7+
////////////////////////////////////////////////////////////////////////////////////////////////////
8+
#define LVARRAY_SGEMM LVARRAY_LAPACK_FORTRAN_MANGLE( sgemm )
9+
void LVARRAY_SGEMM(
10+
char const * TRANSA,
11+
char const * TRANSB,
12+
int const * M,
13+
int const * N,
14+
int const * K,
15+
float const * ALPHA,
16+
float const * A,
17+
int const * LDA,
18+
float const * B,
19+
int const * LDB,
20+
float const * BETA,
21+
float * C,
22+
int const * LDC );
23+
24+
////////////////////////////////////////////////////////////////////////////////////////////////////
25+
#define LVARRAY_DGEMM LVARRAY_LAPACK_FORTRAN_MANGLE( dgemm )
26+
void LVARRAY_DGEMM(
27+
char const * TRANSA,
28+
char const * TRANSB,
29+
int const * M,
30+
int const * N,
31+
int const * K,
32+
double const * ALPHA,
33+
double const * A,
34+
int const * LDA,
35+
double const * B,
36+
int const * LDB,
37+
double const * BETA,
38+
double * C,
39+
int const * LDC );
40+
41+
////////////////////////////////////////////////////////////////////////////////////////////////////
42+
#define LVARRAY_CGEMM LVARRAY_LAPACK_FORTRAN_MANGLE( cgemm )
43+
void LVARRAY_CGEMM(
44+
char const * TRANSA,
45+
char const * TRANSB,
46+
int const * M,
47+
int const * N,
48+
int const * K,
49+
std::complex< float > const * ALPHA,
50+
std::complex< float > const * A,
51+
int const * LDA,
52+
std::complex< float > const * B,
53+
int const * LDB,
54+
std::complex< float > const * BETA,
55+
std::complex< float > * C,
56+
int const * LDC );
57+
58+
////////////////////////////////////////////////////////////////////////////////////////////////////
59+
#define LVARRAY_ZGEMM LVARRAY_LAPACK_FORTRAN_MANGLE( zgemm )
60+
void LVARRAY_ZGEMM(
61+
char const * TRANSA,
62+
char const * TRANSB,
63+
int const * M,
64+
int const * N,
65+
int const * K,
66+
std::complex< double > const * ALPHA,
67+
std::complex< double > const * A,
68+
int const * LDA,
69+
std::complex< double > const * B,
70+
int const * LDB,
71+
std::complex< double > const * BETA,
72+
std::complex< double > * C,
73+
int const * LDC );
74+
75+
////////////////////////////////////////////////////////////////////////////////////////////////////
76+
#define LVARRAY_SGESV LVARRAY_LAPACK_FORTRAN_MANGLE( sgesv )
77+
void LVARRAY_SGESV(
78+
int const * N,
79+
int const * NRHS,
80+
float * A,
81+
int const * LDA,
82+
int * IPIV,
83+
float * B,
84+
int const * LDB,
85+
int * INFO );
86+
87+
////////////////////////////////////////////////////////////////////////////////////////////////////
88+
#define LVARRAY_DGESV LVARRAY_LAPACK_FORTRAN_MANGLE( dgesv )
89+
void LVARRAY_DGESV(
90+
int const * N,
91+
int const * NRHS,
92+
double * A,
93+
int const * LDA,
94+
int * IPIV,
95+
double * B,
96+
int const * LDB,
97+
int * INFO );
98+
99+
////////////////////////////////////////////////////////////////////////////////////////////////////
100+
#define LVARRAY_CGESV LVARRAY_LAPACK_FORTRAN_MANGLE( cgesv )
101+
void LVARRAY_CGESV(
102+
int const * N,
103+
int const * NRHS,
104+
std::complex< float > * A,
105+
int const * LDA,
106+
int * IPIV,
107+
std::complex< float > * B,
108+
int const * LDB,
109+
int * INFO );
110+
111+
////////////////////////////////////////////////////////////////////////////////////////////////////
112+
#define LVARRAY_ZGESV LVARRAY_LAPACK_FORTRAN_MANGLE( zgesv )
113+
void LVARRAY_ZGESV(
114+
int const * N,
115+
int const * NRHS,
116+
std::complex< double > * A,
117+
int const * LDA,
118+
int * IPIV,
119+
std::complex< double > * B,
120+
int const * LDB,
121+
int * INFO );
122+
123+
} // extern "C"
124+
125+
namespace LvArray
126+
{
127+
namespace dense
128+
{
129+
130+
char toLapackChar( Operation const op )
131+
{
132+
if( op == Operation::NO_OP ) return 'N';
133+
if( op == Operation::TRANSPOSE ) return 'T';
134+
if( op == Operation::ADJOINT ) return 'C';
135+
136+
LVARRAY_ERROR( "Unknown operation: " << int( op ) );
137+
return '\0';
138+
}
139+
140+
141+
template< typename T >
142+
void BlasLapackInterface< T >::gemm(
143+
Operation opA,
144+
Operation opB,
145+
T const alpha,
146+
Matrix< T const > const & A,
147+
Matrix< T const > const & B,
148+
T const beta,
149+
Matrix< T > const & C )
150+
{
151+
char const TRANSA = toLapackChar( opA );
152+
char const TRANSB = toLapackChar( opB );
153+
int const M = C.sizes[ 0 ];
154+
int const N = C.sizes[ 1 ];
155+
int const K = opA == Operation::NO_OP ? A.sizes[ 1 ] : A.sizes[ 0 ];
156+
int const LDA = std::max( std::ptrdiff_t{ 1 }, A.strides[ 1 ] );
157+
int const LDB = std::max( std::ptrdiff_t{ 1 }, B.strides[ 1 ] );
158+
int const LDC = std::max( std::ptrdiff_t{ 1 }, C.strides[ 1 ] );
159+
160+
TypeDispatch< T >::dispatch( LVARRAY_SGEMM, LVARRAY_DGEMM, LVARRAY_CGEMM, LVARRAY_ZGEMM,
161+
&TRANSA,
162+
&TRANSB,
163+
&M,
164+
&N,
165+
&K,
166+
&alpha,
167+
A.data,
168+
&LDA,
169+
B.data,
170+
&LDB,
171+
&beta,
172+
C.data,
173+
&LDC );
174+
}
175+
176+
177+
template< typename T >
178+
void BlasLapackInterface< T >::gesv(
179+
Matrix< T > const & A,
180+
Matrix< T > const & B,
181+
Vector< int > const & pivots )
182+
{
183+
int const N = A.sizes[ 0 ];
184+
int const NRHS = B.sizes[ 1 ];
185+
int const LDA = A.strides[ 1 ];
186+
int const LDB = B.strides[ 1 ];
187+
int INFO = 0;
188+
189+
TypeDispatch< T >::dispatch( LVARRAY_SGESV, LVARRAY_DGESV, LVARRAY_CGESV, LVARRAY_ZGESV,
190+
&N,
191+
&NRHS,
192+
A.data,
193+
&LDA,
194+
pivots.data,
195+
B.data,
196+
&LDB,
197+
&INFO );
198+
199+
LVARRAY_ERROR_IF( INFO < 0, "The " << -INFO << "-th argument had an illegal value." );
200+
LVARRAY_ERROR_IF( INFO > 0, "The factorization has been completed but U( " << INFO - 1 << ", " << INFO - 1 <<
201+
" ) is exactly zero so the solution could not be computed." );
202+
}
203+
204+
template class BlasLapackInterface< float >;
205+
template class BlasLapackInterface< double >;
206+
template class BlasLapackInterface< std::complex< float > >;
207+
template class BlasLapackInterface< std::complex< double > >;
208+
209+
} // namespace dense
210+
} // namespace LvArray

src/dense/BlasLapackInterface.hpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#pragma once
2+
3+
#include "common.hpp"
4+
5+
namespace LvArray
6+
{
7+
namespace dense
8+
{
9+
10+
template< typename T >
11+
struct BlasLapackInterface
12+
{
13+
static constexpr MemorySpace MEMORY_SPACE = MemorySpace::host;
14+
15+
static void gemm(
16+
Operation opA,
17+
Operation opB,
18+
T const alpha,
19+
Matrix< T const > const & A,
20+
Matrix< T const > const & B,
21+
T const beta,
22+
Matrix< T > const & C );
23+
24+
static void gesv(
25+
Matrix< T > const & A,
26+
Matrix< T > const & B,
27+
Vector< int > const & pivots );
28+
};
29+
30+
} // namespace dense
31+
} // namespace LvArray

src/dense/CMakeLists.txt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
set( lvarraydense_headers
22
common.hpp
33
backendHelpers.hpp
4-
eigenDecomposition.hpp
5-
linearSolve.hpp
4+
BlasLapackInterface.hpp
65
)
76

87
set( lvarraydense_sources
98
common.cpp
10-
eigenDecomposition.cpp
11-
linearSolve.cpp
9+
BlasLapackInterface.cpp
1210
)
1311

1412
set( dependencies lvarray ${lvarray_dependencies} blas lapack )

src/dense/backendHelpers.hpp

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,82 @@
11
#pragma once
22

3-
#if defined( LVARRAY_USE_MAGMA )
4-
#include <magma.h>
5-
#endif
3+
#include <complex>
64

75
/// This macro provide a flexible interface for Fortran naming convention for compiled objects
86
// #ifdef FORTRAN_MANGLE_NO_UNDERSCORE
97
#define LVARRAY_LAPACK_FORTRAN_MANGLE( name ) name
108
// #else
119
// #define LVARRAY_LAPACK_FORTRAN_MANGLE( name ) name ## _
12-
// #endif
10+
// #endif
11+
12+
namespace LvArray
13+
{
14+
namespace dense
15+
{
16+
17+
template< typename T >
18+
struct TypeDispatch
19+
{};
20+
21+
template<>
22+
struct TypeDispatch< float >
23+
{
24+
template< typename F_FLOAT, typename F_DOUBLE, typename F_CFLOAT, typename F_CDOUBLE, typename ... ARGS >
25+
static constexpr auto dispatch(
26+
F_FLOAT && fFloat,
27+
F_DOUBLE &&,
28+
F_CFLOAT &&,
29+
F_CDOUBLE &&,
30+
ARGS && ... args )
31+
{
32+
return fFloat( std::forward< ARGS >( args ) ... );
33+
}
34+
};
35+
36+
template<>
37+
struct TypeDispatch< double >
38+
{
39+
template< typename F_FLOAT, typename F_DOUBLE, typename F_CFLOAT, typename F_CDOUBLE, typename ... ARGS >
40+
static constexpr auto dispatch(
41+
F_FLOAT &&,
42+
F_DOUBLE && fDouble,
43+
F_CFLOAT &&,
44+
F_CDOUBLE &&,
45+
ARGS && ... args )
46+
{
47+
return fDouble( std::forward< ARGS >( args ) ... );
48+
}
49+
};
50+
51+
template<>
52+
struct TypeDispatch< std::complex< float > >
53+
{
54+
template< typename F_FLOAT, typename F_DOUBLE, typename F_CFLOAT, typename F_CDOUBLE, typename ... ARGS >
55+
static constexpr auto dispatch(
56+
F_FLOAT &&,
57+
F_DOUBLE &&,
58+
F_CFLOAT && fCFloat,
59+
F_CDOUBLE &&,
60+
ARGS && ... args )
61+
{
62+
return fCFloat( std::forward< ARGS >( args ) ... );
63+
}
64+
};
65+
66+
template<>
67+
struct TypeDispatch< std::complex< double > >
68+
{
69+
template< typename F_FLOAT, typename F_DOUBLE, typename F_CFLOAT, typename F_CDOUBLE, typename ... ARGS >
70+
static constexpr auto dispatch(
71+
F_FLOAT &&,
72+
F_DOUBLE &&,
73+
F_CFLOAT &&,
74+
F_CDOUBLE && fCDouble,
75+
ARGS && ... args )
76+
{
77+
return fCDouble( std::forward< ARGS >( args ) ... );
78+
}
79+
};
80+
81+
} // namespace dense
82+
} // namespace LvArray

0 commit comments

Comments
 (0)