1
+ #include " BlasLapackInterface.hpp"
2
+ #include " backendHelpers.hpp"
3
+
4
+ extern " C"
5
+ {
6
+
7
+ // //////////////////////////////////////////////////////////////////////////////////////////////////
8
+ #define LVARRAY_SGEMM LVARRAY_LAPACK_FORTRAN_MANGLE ( sgemm )
9
+ void LVARRAY_SGEMM(
10
+ char const * TRANSA,
11
+ char const * TRANSB,
12
+ int const * M,
13
+ int const * N,
14
+ int const * K,
15
+ float const * ALPHA,
16
+ float const * A,
17
+ int const * LDA,
18
+ float const * B,
19
+ int const * LDB,
20
+ float const * BETA,
21
+ float * C,
22
+ int const * LDC );
23
+
24
+ // //////////////////////////////////////////////////////////////////////////////////////////////////
25
+ #define LVARRAY_DGEMM LVARRAY_LAPACK_FORTRAN_MANGLE ( dgemm )
26
+ void LVARRAY_DGEMM(
27
+ char const * TRANSA,
28
+ char const * TRANSB,
29
+ int const * M,
30
+ int const * N,
31
+ int const * K,
32
+ double const * ALPHA,
33
+ double const * A,
34
+ int const * LDA,
35
+ double const * B,
36
+ int const * LDB,
37
+ double const * BETA,
38
+ double * C,
39
+ int const * LDC );
40
+
41
+ // //////////////////////////////////////////////////////////////////////////////////////////////////
42
+ #define LVARRAY_CGEMM LVARRAY_LAPACK_FORTRAN_MANGLE ( cgemm )
43
+ void LVARRAY_CGEMM(
44
+ char const * TRANSA,
45
+ char const * TRANSB,
46
+ int const * M,
47
+ int const * N,
48
+ int const * K,
49
+ std::complex< float > const * ALPHA,
50
+ std::complex< float > const * A,
51
+ int const * LDA,
52
+ std::complex< float > const * B,
53
+ int const * LDB,
54
+ std::complex< float > const * BETA,
55
+ std::complex< float > * C,
56
+ int const * LDC );
57
+
58
+ // //////////////////////////////////////////////////////////////////////////////////////////////////
59
+ #define LVARRAY_ZGEMM LVARRAY_LAPACK_FORTRAN_MANGLE ( zgemm )
60
+ void LVARRAY_ZGEMM(
61
+ char const * TRANSA,
62
+ char const * TRANSB,
63
+ int const * M,
64
+ int const * N,
65
+ int const * K,
66
+ std::complex< double > const * ALPHA,
67
+ std::complex< double > const * A,
68
+ int const * LDA,
69
+ std::complex< double > const * B,
70
+ int const * LDB,
71
+ std::complex< double > const * BETA,
72
+ std::complex< double > * C,
73
+ int const * LDC );
74
+
75
+ // //////////////////////////////////////////////////////////////////////////////////////////////////
76
+ #define LVARRAY_SGESV LVARRAY_LAPACK_FORTRAN_MANGLE ( sgesv )
77
+ void LVARRAY_SGESV(
78
+ int const * N,
79
+ int const * NRHS,
80
+ float * A,
81
+ int const * LDA,
82
+ int * IPIV,
83
+ float * B,
84
+ int const * LDB,
85
+ int * INFO );
86
+
87
+ // //////////////////////////////////////////////////////////////////////////////////////////////////
88
+ #define LVARRAY_DGESV LVARRAY_LAPACK_FORTRAN_MANGLE ( dgesv )
89
+ void LVARRAY_DGESV(
90
+ int const * N,
91
+ int const * NRHS,
92
+ double * A,
93
+ int const * LDA,
94
+ int * IPIV,
95
+ double * B,
96
+ int const * LDB,
97
+ int * INFO );
98
+
99
+ // //////////////////////////////////////////////////////////////////////////////////////////////////
100
+ #define LVARRAY_CGESV LVARRAY_LAPACK_FORTRAN_MANGLE ( cgesv )
101
+ void LVARRAY_CGESV(
102
+ int const * N,
103
+ int const * NRHS,
104
+ std::complex< float > * A,
105
+ int const * LDA,
106
+ int * IPIV,
107
+ std::complex< float > * B,
108
+ int const * LDB,
109
+ int * INFO );
110
+
111
+ // //////////////////////////////////////////////////////////////////////////////////////////////////
112
+ #define LVARRAY_ZGESV LVARRAY_LAPACK_FORTRAN_MANGLE ( zgesv )
113
+ void LVARRAY_ZGESV(
114
+ int const * N,
115
+ int const * NRHS,
116
+ std::complex< double > * A,
117
+ int const * LDA,
118
+ int * IPIV,
119
+ std::complex< double > * B,
120
+ int const * LDB,
121
+ int * INFO );
122
+
123
+ } // extern "C"
124
+
125
+ namespace LvArray
126
+ {
127
+ namespace dense
128
+ {
129
+
130
+ char toLapackChar ( Operation const op )
131
+ {
132
+ if ( op == Operation::NO_OP ) return ' N' ;
133
+ if ( op == Operation::TRANSPOSE ) return ' T' ;
134
+ if ( op == Operation::ADJOINT ) return ' C' ;
135
+
136
+ LVARRAY_ERROR ( " Unknown operation: " << int ( op ) );
137
+ return ' \0 ' ;
138
+ }
139
+
140
+
141
+ template < typename T >
142
+ void BlasLapackInterface< T >::gemm(
143
+ Operation opA,
144
+ Operation opB,
145
+ T const alpha,
146
+ Matrix< T const > const & A,
147
+ Matrix< T const > const & B,
148
+ T const beta,
149
+ Matrix< T > const & C )
150
+ {
151
+ char const TRANSA = toLapackChar ( opA );
152
+ char const TRANSB = toLapackChar ( opB );
153
+ int const M = C.sizes [ 0 ];
154
+ int const N = C.sizes [ 1 ];
155
+ int const K = opA == Operation::NO_OP ? A.sizes [ 1 ] : A.sizes [ 0 ];
156
+ int const LDA = std::max ( std::ptrdiff_t { 1 }, A.strides [ 1 ] );
157
+ int const LDB = std::max ( std::ptrdiff_t { 1 }, B.strides [ 1 ] );
158
+ int const LDC = std::max ( std::ptrdiff_t { 1 }, C.strides [ 1 ] );
159
+
160
+ TypeDispatch< T >::dispatch ( LVARRAY_SGEMM, LVARRAY_DGEMM, LVARRAY_CGEMM, LVARRAY_ZGEMM,
161
+ &TRANSA,
162
+ &TRANSB,
163
+ &M,
164
+ &N,
165
+ &K,
166
+ &alpha,
167
+ A.data ,
168
+ &LDA,
169
+ B.data ,
170
+ &LDB,
171
+ &beta,
172
+ C.data ,
173
+ &LDC );
174
+ }
175
+
176
+
177
+ template < typename T >
178
+ void BlasLapackInterface< T >::gesv(
179
+ Matrix< T > const & A,
180
+ Matrix< T > const & B,
181
+ Vector< int > const & pivots )
182
+ {
183
+ int const N = A.sizes [ 0 ];
184
+ int const NRHS = B.sizes [ 1 ];
185
+ int const LDA = A.strides [ 1 ];
186
+ int const LDB = B.strides [ 1 ];
187
+ int INFO = 0 ;
188
+
189
+ TypeDispatch< T >::dispatch ( LVARRAY_SGESV, LVARRAY_DGESV, LVARRAY_CGESV, LVARRAY_ZGESV,
190
+ &N,
191
+ &NRHS,
192
+ A.data ,
193
+ &LDA,
194
+ pivots.data ,
195
+ B.data ,
196
+ &LDB,
197
+ &INFO );
198
+
199
+ LVARRAY_ERROR_IF ( INFO < 0 , " The " << -INFO << " -th argument had an illegal value." );
200
+ LVARRAY_ERROR_IF ( INFO > 0 , " The factorization has been completed but U( " << INFO - 1 << " , " << INFO - 1 <<
201
+ " ) is exactly zero so the solution could not be computed." );
202
+ }
203
+
204
+ template class BlasLapackInterface < float >;
205
+ template class BlasLapackInterface < double >;
206
+ template class BlasLapackInterface < std::complex< float > >;
207
+ template class BlasLapackInterface < std::complex< double > >;
208
+
209
+ } // namespace dense
210
+ } // namespace LvArray
0 commit comments