GetFEM++  5.3
gmm_blas_interface.h
Go to the documentation of this file.
1 /* -*- c++ -*- (enables emacs c++ mode) */
2 /*===========================================================================
3 
4  Copyright (C) 2003-2017 Yves Renard
5 
6  This file is a part of GetFEM++
7 
8  GetFEM++ is free software; you can redistribute it and/or modify it
9  under the terms of the GNU Lesser General Public License as published
10  by the Free Software Foundation; either version 3 of the License, or
11  (at your option) any later version along with the GCC Runtime Library
12  Exception either version 3.1 or (at your option) any later version.
13  This program is distributed in the hope that it will be useful, but
14  WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
15  or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
16  License and GCC Runtime Library Exception for more details.
17  You should have received a copy of the GNU Lesser General Public License
18  along with this program; if not, write to the Free Software Foundation,
19  Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA.
20 
21  As a special exception, you may use this file as it is a part of a free
22  software library without restriction. Specifically, if other files
23  instantiate templates or use macros or inline functions from this file,
24  or you compile this file and link it with other files to produce an
25  executable, this file does not by itself cause the resulting executable
26  to be covered by the GNU Lesser General Public License. This exception
27  does not however invalidate any other reasons why the executable file
28  might be covered by the GNU Lesser General Public License.
29 
30 ===========================================================================*/
31 
32 /**@file gmm_blas_interface.h
33  @author Yves Renard <Yves.Renard@insa-lyon.fr>
34  @date October 7, 2003.
35  @brief gmm interface for fortran BLAS.
36 */
37 
38 #if defined(GETFEM_USES_BLAS) || defined(GMM_USES_BLAS) \
39  || defined(GMM_USES_LAPACK) || defined(GMM_USES_ATLAS)
40 
41 #ifndef GMM_BLAS_INTERFACE_H
42 #define GMM_BLAS_INTERFACE_H
43 
44 #include "gmm_blas.h"
45 #include "gmm_interface.h"
46 #include "gmm_matrix.h"
47 
48 namespace gmm {
49 
50  // Use ./configure --enable-blas-interface to activate this interface.
51 
// Tracing hook used at the top of every BLAS wrapper below. As defined here
// it expands to nothing, so each GMMLAPACK_TRACE("...") call compiles away;
// swap in the commented-out definition to print the wrapper name on each call.
52 #define GMMLAPACK_TRACE(f)
53  // #define GMMLAPACK_TRACE(f) cout << "function " << f << " called" << endl;
54 
55  /* ********************************************************************* */
56  /* Operations interfaced for T = float, double, std::complex<float> */
57  /* or std::complex<double> : */
58  /* */
59  /* vect_norm2(std::vector<T>) */
60  /* */
61  /* vect_sp(std::vector<T>, std::vector<T>) */
62  /* vect_sp(scaled(std::vector<T>), std::vector<T>) */
63  /* vect_sp(std::vector<T>, scaled(std::vector<T>)) */
64  /* vect_sp(scaled(std::vector<T>), scaled(std::vector<T>)) */
65  /* */
66  /* vect_hp(std::vector<T>, std::vector<T>) */
67  /* vect_hp(scaled(std::vector<T>), std::vector<T>) */
68  /* vect_hp(std::vector<T>, scaled(std::vector<T>)) */
69  /* vect_hp(scaled(std::vector<T>), scaled(std::vector<T>)) */
70  /* */
71  /* add(std::vector<T>, std::vector<T>) */
72  /* add(scaled(std::vector<T>, a), std::vector<T>) */
73  /* */
74  /* mult(dense_matrix<T>, dense_matrix<T>, dense_matrix<T>) */
75  /* mult(transposed(dense_matrix<T>), dense_matrix<T>, dense_matrix<T>) */
76  /* mult(dense_matrix<T>, transposed(dense_matrix<T>), dense_matrix<T>) */
77  /* mult(transposed(dense_matrix<T>), transposed(dense_matrix<T>), */
78  /* dense_matrix<T>) */
79  /* mult(conjugated(dense_matrix<T>), dense_matrix<T>, dense_matrix<T>) */
80  /* mult(dense_matrix<T>, conjugated(dense_matrix<T>), dense_matrix<T>) */
81  /* mult(conjugated(dense_matrix<T>), conjugated(dense_matrix<T>), */
82  /* dense_matrix<T>) */
83  /* */
84  /* mult(dense_matrix<T>, std::vector<T>, std::vector<T>) */
85  /* mult(transposed(dense_matrix<T>), std::vector<T>, std::vector<T>) */
86  /* mult(conjugated(dense_matrix<T>), std::vector<T>, std::vector<T>) */
87  /* mult(dense_matrix<T>, scaled(std::vector<T>), std::vector<T>) */
88  /* mult(transposed(dense_matrix<T>), scaled(std::vector<T>), */
89  /* std::vector<T>) */
90  /* mult(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
91  /* std::vector<T>) */
92  /* */
93  /* mult_add(dense_matrix<T>, std::vector<T>, std::vector<T>) */
94  /* mult_add(transposed(dense_matrix<T>), std::vector<T>, std::vector<T>) */
95  /* mult_add(conjugated(dense_matrix<T>), std::vector<T>, std::vector<T>) */
96  /* mult_add(dense_matrix<T>, scaled(std::vector<T>), std::vector<T>) */
97  /* mult_add(transposed(dense_matrix<T>), scaled(std::vector<T>), */
98  /* std::vector<T>) */
99  /* mult_add(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
100  /* std::vector<T>) */
101  /* */
102  /* mult(dense_matrix<T>, std::vector<T>, std::vector<T>, std::vector<T>) */
103  /* mult(transposed(dense_matrix<T>), std::vector<T>, std::vector<T>, */
104  /* std::vector<T>) */
105  /* mult(conjugated(dense_matrix<T>), std::vector<T>, std::vector<T>, */
106  /* std::vector<T>) */
107  /* mult(dense_matrix<T>, scaled(std::vector<T>), std::vector<T>, */
108  /* std::vector<T>) */
109  /* mult(transposed(dense_matrix<T>), scaled(std::vector<T>), */
110  /* std::vector<T>, std::vector<T>) */
111  /* mult(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
112  /* std::vector<T>, std::vector<T>) */
113  /* mult(dense_matrix<T>, std::vector<T>, scaled(std::vector<T>), */
114  /* std::vector<T>) */
115  /* mult(transposed(dense_matrix<T>), std::vector<T>, */
116  /* scaled(std::vector<T>), std::vector<T>) */
117  /* mult(conjugated(dense_matrix<T>), std::vector<T>, */
118  /* scaled(std::vector<T>), std::vector<T>) */
119  /* mult(dense_matrix<T>, scaled(std::vector<T>), scaled(std::vector<T>), */
120  /* std::vector<T>) */
121  /* mult(transposed(dense_matrix<T>), scaled(std::vector<T>), */
122  /* scaled(std::vector<T>), std::vector<T>) */
123  /* mult(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
124  /* scaled(std::vector<T>), std::vector<T>) */
125  /* */
126  /* lower_tri_solve(dense_matrix<T>, std::vector<T>, k, b) */
127  /* upper_tri_solve(dense_matrix<T>, std::vector<T>, k, b) */
128  /* lower_tri_solve(transposed(dense_matrix<T>), std::vector<T>, k, b) */
129  /* upper_tri_solve(transposed(dense_matrix<T>), std::vector<T>, k, b) */
130  /* lower_tri_solve(conjugated(dense_matrix<T>), std::vector<T>, k, b) */
131  /* upper_tri_solve(conjugated(dense_matrix<T>), std::vector<T>, k, b) */
132  /* */
133  /* rank_one_update(dense_matrix<T>, std::vector<T>, std::vector<T>) */
134  /* rank_one_update(dense_matrix<T>, scaled(std::vector<T>), */
135  /* std::vector<T>) */
136  /* rank_one_update(dense_matrix<T>, std::vector<T>, */
137  /* scaled(std::vector<T>)) */
138  /* */
139  /* ********************************************************************* */
140 
141  /* ********************************************************************* */
142  /* Basic defines. */
143  /* ********************************************************************* */
144 
// Aliases for the four BLAS scalar types, named after the BLAS precision
// prefixes: S = float, D = double, C = std::complex<float>,
// Z = std::complex<double>. They are passed as the base_type argument of the
// wrapper-generating macros that follow.
145 # define BLAS_S float
146 # define BLAS_D double
147 # define BLAS_C std::complex<float>
148 # define BLAS_Z std::complex<double>
149 
150  /* ********************************************************************* */
151  /* BLAS functions used. */
152  /* ********************************************************************* */
  // Fortran BLAS entry points (lower-case name plus trailing underscore).
  // Every argument is passed by pointer; sizes and increments are passed as
  // long.
  // NOTE(review): this assumes the linked BLAS reads its INTEGER arguments
  // with the same width as long (on LP64 platforms long is 64-bit while a
  // standard BLAS INTEGER is usually 32-bit) -- confirm for the target BLAS.
153  extern "C" {
  // daxpy_ and dgemm_ have full prototypes, so calls to them are type-checked.
154  void daxpy_(const long *n, const double *alpha, const double *x,
155  const long *incx, double *y, const long *incy);
156  void dgemm_(const char *tA, const char *tB, const long *m,
157  const long *n, const long *k, const double *alpha,
158  const double *A, const long *ldA, const double *B,
159  const long *ldB, const double *beta, double *C,
160  const long *ldC);
  // The remaining routines are declared variadic: one declaration accepts any
  // argument list, so no argument type checking happens at the call sites.
  // The variadic daxpy_ is left commented out because daxpy_ is already
  // prototyped above and an extern "C" function cannot be redeclared with a
  // different parameter list.
161  void sgemm_(...); void cgemm_(...); void zgemm_(...);
162  void sgemv_(...); void dgemv_(...); void cgemv_(...); void zgemv_(...);
163  void strsv_(...); void dtrsv_(...); void ctrsv_(...); void ztrsv_(...);
164  void saxpy_(...); /*void daxpy_(...); */void caxpy_(...); void zaxpy_(...);
  // Scalar-returning routines. The complex ones are declared to return
  // std::complex by value.
  // NOTE(review): whether this matches how the Fortran library returns complex
  // function results depends on the compiler/ABI -- verify per toolchain.
165  BLAS_S sdot_ (...); BLAS_D ddot_ (...);
166  BLAS_C cdotu_(...); BLAS_Z zdotu_(...);
167  BLAS_C cdotc_(...); BLAS_Z zdotc_(...);
168  BLAS_S snrm2_(...); BLAS_D dnrm2_(...);
169  BLAS_S scnrm2_(...); BLAS_D dznrm2_(...);
170  void sger_(...); void dger_(...); void cgerc_(...); void zgerc_(...);
171  }
172 
173 #if 1
174 
175  /* ********************************************************************* */
176  /* vect_norm2(x). */
177  /* ********************************************************************* */
178 
179 # define nrm2_interface(param1, trans1, blas_name, base_type) \
180  inline number_traits<base_type >::magnitude_type \
181  vect_norm2(param1(base_type)) { \
182  GMMLAPACK_TRACE("nrm2_interface"); \
183  long inc(1), n(long(vect_size(x))); trans1(base_type); \
184  return blas_name(&n, &x[0], &inc); \
185  }
186 
187 # define nrm2_p1(base_type) const std::vector<base_type > &x
188 # define nrm2_trans1(base_type)
189 
190  nrm2_interface(nrm2_p1, nrm2_trans1, snrm2_ , BLAS_S)
191  nrm2_interface(nrm2_p1, nrm2_trans1, dnrm2_ , BLAS_D)
192  nrm2_interface(nrm2_p1, nrm2_trans1, scnrm2_, BLAS_C)
193  nrm2_interface(nrm2_p1, nrm2_trans1, dznrm2_, BLAS_Z)
194 
195  /* ********************************************************************* */
196  /* vect_sp(x, y). */
197  /* ********************************************************************* */
198 
199 # define dot_interface(param1, trans1, mult1, param2, trans2, mult2, \
200  blas_name, base_type) \
201  inline base_type vect_sp(param1(base_type), param2(base_type)) { \
202  GMMLAPACK_TRACE("dot_interface"); \
203  trans1(base_type); trans2(base_type); long inc(1), n(long(vect_size(y)));\
204  return mult1 mult2 blas_name(&n, &x[0], &inc, &y[0], &inc); \
205  }
206 
207 # define dot_p1(base_type) const std::vector<base_type > &x
208 # define dot_trans1(base_type)
209 # define dot_p1_s(base_type) \
210  const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_
211 # define dot_trans1_s(base_type) \
212  std::vector<base_type > &x = \
213  const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \
214  base_type a(x_.r)
215 
216 # define dot_p2(base_type) const std::vector<base_type > &y
217 # define dot_trans2(base_type)
218 # define dot_p2_s(base_type) \
219  const scaled_vector_const_ref<std::vector<base_type >, base_type > &y_
220 # define dot_trans2_s(base_type) \
221  std::vector<base_type > &y = \
222  const_cast<std::vector<base_type > &>(*(linalg_origin(y_))); \
223  base_type b(y_.r)
224 
225  dot_interface(dot_p1, dot_trans1, (BLAS_S), dot_p2, dot_trans2, (BLAS_S),
226  sdot_ , BLAS_S)
227  dot_interface(dot_p1, dot_trans1, (BLAS_D), dot_p2, dot_trans2, (BLAS_D),
228  ddot_ , BLAS_D)
229  dot_interface(dot_p1, dot_trans1, (BLAS_C), dot_p2, dot_trans2, (BLAS_C),
230  cdotu_, BLAS_C)
231  dot_interface(dot_p1, dot_trans1, (BLAS_Z), dot_p2, dot_trans2, (BLAS_Z),
232  zdotu_, BLAS_Z)
233 
234  dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_S),
235  sdot_ ,BLAS_S)
236  dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_D),
237  ddot_ ,BLAS_D)
238  dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_C),
239  cdotu_,BLAS_C)
240  dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_Z),
241  zdotu_,BLAS_Z)
242 
243  dot_interface(dot_p1, dot_trans1, (BLAS_S), dot_p2_s, dot_trans2_s, b*,
244  sdot_ ,BLAS_S)
245  dot_interface(dot_p1, dot_trans1, (BLAS_D), dot_p2_s, dot_trans2_s, b*,
246  ddot_ ,BLAS_D)
247  dot_interface(dot_p1, dot_trans1, (BLAS_C), dot_p2_s, dot_trans2_s, b*,
248  cdotu_,BLAS_C)
249  dot_interface(dot_p1, dot_trans1, (BLAS_Z), dot_p2_s, dot_trans2_s, b*,
250  zdotu_,BLAS_Z)
251 
252  dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,sdot_ ,
253  BLAS_S)
254  dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,ddot_ ,
255  BLAS_D)
256  dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,cdotu_,
257  BLAS_C)
258  dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,zdotu_,
259  BLAS_Z)
260 
261 
262  /* ********************************************************************* */
263  /* vect_hp(x, y). */
264  /* ********************************************************************* */
265 
266 # define dotc_interface(param1, trans1, mult1, param2, trans2, mult2, \
267  blas_name, base_type) \
268  inline base_type vect_hp(param1(base_type), param2(base_type)) { \
269  GMMLAPACK_TRACE("dotc_interface"); \
270  trans1(base_type); trans2(base_type); long inc(1), n(long(vect_size(y)));\
271  return mult1 mult2 blas_name(&n, &x[0], &inc, &y[0], &inc); \
272  }
273 
274 # define dotc_p1(base_type) const std::vector<base_type > &x
275 # define dotc_trans1(base_type)
276 # define dotc_p1_s(base_type) \
277  const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_
278 # define dotc_trans1_s(base_type) \
279  std::vector<base_type > &x = \
280  const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \
281  base_type a(x_.r)
282 
283 # define dotc_p2(base_type) const std::vector<base_type > &y
284 # define dotc_trans2(base_type)
285 # define dotc_p2_s(base_type) \
286  const scaled_vector_const_ref<std::vector<base_type >, base_type > &y_
287 # define dotc_trans2_s(base_type) \
288  std::vector<base_type > &y = \
289  const_cast<std::vector<base_type > &>(*(linalg_origin(y_))); \
290  base_type b(gmm::conj(y_.r))
291 
292  dotc_interface(dotc_p1, dotc_trans1, (BLAS_S), dotc_p2, dotc_trans2,
293  (BLAS_S),sdot_ ,BLAS_S)
294  dotc_interface(dotc_p1, dotc_trans1, (BLAS_D), dotc_p2, dotc_trans2,
295  (BLAS_D),ddot_ ,BLAS_D)
296  dotc_interface(dotc_p1, dotc_trans1, (BLAS_C), dotc_p2, dotc_trans2,
297  (BLAS_C),cdotc_,BLAS_C)
298  dotc_interface(dotc_p1, dotc_trans1, (BLAS_Z), dotc_p2, dotc_trans2,
299  (BLAS_Z),zdotc_,BLAS_Z)
300 
301  dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
302  (BLAS_S),sdot_, BLAS_S)
303  dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
304  (BLAS_D),ddot_ , BLAS_D)
305  dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
306  (BLAS_C),cdotc_, BLAS_C)
307  dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
308  (BLAS_Z),zdotc_, BLAS_Z)
309 
310  dotc_interface(dotc_p1, dotc_trans1, (BLAS_S), dotc_p2_s, dotc_trans2_s,
311  b*,sdot_ , BLAS_S)
312  dotc_interface(dotc_p1, dotc_trans1, (BLAS_D), dotc_p2_s, dotc_trans2_s,
313  b*,ddot_ , BLAS_D)
314  dotc_interface(dotc_p1, dotc_trans1, (BLAS_C), dotc_p2_s, dotc_trans2_s,
315  b*,cdotc_, BLAS_C)
316  dotc_interface(dotc_p1, dotc_trans1, (BLAS_Z), dotc_p2_s, dotc_trans2_s,
317  b*,zdotc_, BLAS_Z)
318 
319  dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,sdot_ ,
320  BLAS_S)
321  dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,ddot_ ,
322  BLAS_D)
323  dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,cdotc_,
324  BLAS_C)
325  dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,zdotc_,
326  BLAS_Z)
327 
328  /* ********************************************************************* */
329  /* add(x, y). */
330  /* ********************************************************************* */
331 
332 # define axpy_interface(param1, trans1, blas_name, base_type) \
333  inline void add(param1(base_type), std::vector<base_type > &y) { \
334  GMMLAPACK_TRACE("axpy_interface"); \
335  long inc(1), n(long(vect_size(y))); trans1(base_type); \
336  if (n == 0) return; \
337  blas_name(&n, &a, &x[0], &inc, &y[0], &inc); \
338  }
339 
340 # define axpy_p1(base_type) const std::vector<base_type > &x
341 # define axpy_trans1(base_type) base_type a(1)
342 # define axpy_p1_s(base_type) \
343  const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_
344 # define axpy_trans1_s(base_type) \
345  std::vector<base_type > &x = \
346  const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \
347  base_type a(x_.r)
348 
349  axpy_interface(axpy_p1, axpy_trans1, saxpy_, BLAS_S)
350  axpy_interface(axpy_p1, axpy_trans1, daxpy_, BLAS_D)
351  axpy_interface(axpy_p1, axpy_trans1, caxpy_, BLAS_C)
352  axpy_interface(axpy_p1, axpy_trans1, zaxpy_, BLAS_Z)
353 
354  axpy_interface(axpy_p1_s, axpy_trans1_s, saxpy_, BLAS_S)
355  axpy_interface(axpy_p1_s, axpy_trans1_s, daxpy_, BLAS_D)
356  axpy_interface(axpy_p1_s, axpy_trans1_s, caxpy_, BLAS_C)
357  axpy_interface(axpy_p1_s, axpy_trans1_s, zaxpy_, BLAS_Z)
358 
359 
360  /* ********************************************************************* */
361  /* mult_add(A, x, z). */
362  /* ********************************************************************* */
363 
364 # define gemv_interface(param1, trans1, param2, trans2, blas_name, \
365  base_type, orien) \
366  inline void mult_add_spec(param1(base_type), param2(base_type), \
367  std::vector<base_type > &z, orien) { \
368  GMMLAPACK_TRACE("gemv_interface"); \
369  trans1(base_type); trans2(base_type); base_type beta(1); \
370  long m(long(mat_nrows(A))), lda(m), n(long(mat_ncols(A))), inc(1); \
371  if (m && n) blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc, \
372  &beta, &z[0], &inc); \
373  else gmm::clear(z); \
374  }
375 
376  // First parameter
377 # define gem_p1_n(base_type) const dense_matrix<base_type > &A
378 # define gem_trans1_n(base_type) const char t = 'N'
379 # define gem_p1_t(base_type) \
380  const transposed_col_ref<dense_matrix<base_type > *> &A_
381 # define gem_trans1_t(base_type) dense_matrix<base_type > &A = \
382  const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
383  const char t = 'T'
384 # define gem_p1_tc(base_type) \
385  const transposed_col_ref<const dense_matrix<base_type > *> &A_
386 # define gem_p1_c(base_type) \
387  const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &A_
388 # define gem_trans1_c(base_type) dense_matrix<base_type > &A = \
389  const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
390  const char t = 'C'
391 
392  // second parameter
393 # define gemv_p2_n(base_type) const std::vector<base_type > &x
394 # define gemv_trans2_n(base_type) base_type alpha(1)
395 # define gemv_p2_s(base_type) \
396  const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_
397 # define gemv_trans2_s(base_type) std::vector<base_type > &x = \
398  const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \
399  base_type alpha(x_.r)
400 
401  // Z <- AX + Z.
402  gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, sgemv_,
403  BLAS_S, col_major)
404  gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, dgemv_,
405  BLAS_D, col_major)
406  gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, cgemv_,
407  BLAS_C, col_major)
408  gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, zgemv_,
409  BLAS_Z, col_major)
410 
411  // Z <- transposed(A)X + Z.
412  gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
413  BLAS_S, row_major)
414  gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
415  BLAS_D, row_major)
416  gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
417  BLAS_C, row_major)
418  gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
419  BLAS_Z, row_major)
420 
421  // Z <- transposed(const A)X + Z.
422  gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
423  BLAS_S, row_major)
424  gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
425  BLAS_D, row_major)
426  gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
427  BLAS_C, row_major)
428  gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
429  BLAS_Z, row_major)
430 
431  // Z <- conjugated(A)X + Z.
432  gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, sgemv_,
433  BLAS_S, row_major)
434  gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, dgemv_,
435  BLAS_D, row_major)
436  gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, cgemv_,
437  BLAS_C, row_major)
438  gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, zgemv_,
439  BLAS_Z, row_major)
440 
441  // Z <- A scaled(X) + Z.
442  gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, sgemv_,
443  BLAS_S, col_major)
444  gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, dgemv_,
445  BLAS_D, col_major)
446  gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, cgemv_,
447  BLAS_C, col_major)
448  gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, zgemv_,
449  BLAS_Z, col_major)
450 
451  // Z <- transposed(A) scaled(X) + Z.
452  gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
453  BLAS_S, row_major)
454  gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
455  BLAS_D, row_major)
456  gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
457  BLAS_C, row_major)
458  gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
459  BLAS_Z, row_major)
460 
461  // Z <- transposed(const A) scaled(X) + Z.
462  gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
463  BLAS_S, row_major)
464  gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
465  BLAS_D, row_major)
466  gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
467  BLAS_C, row_major)
468  gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
469  BLAS_Z, row_major)
470 
471  // Z <- conjugated(A) scaled(X) + Z.
472  gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, sgemv_,
473  BLAS_S, row_major)
474  gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, dgemv_,
475  BLAS_D, row_major)
476  gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, cgemv_,
477  BLAS_C, row_major)
478  gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, zgemv_,
479  BLAS_Z, row_major)
480 
481 
482  /* ********************************************************************* */
483  /* mult(A, x, y). */
484  /* ********************************************************************* */
485 
486 # define gemv_interface2(param1, trans1, param2, trans2, blas_name, \
487  base_type, orien) \
488  inline void mult_spec(param1(base_type), param2(base_type), \
489  std::vector<base_type > &z, orien) { \
490  GMMLAPACK_TRACE("gemv_interface2"); \
491  trans1(base_type); trans2(base_type); base_type beta(0); \
492  long m(long(mat_nrows(A))), lda(m), n(long(mat_ncols(A))), inc(1); \
493  if (m && n) \
494  blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc, &beta, \
495  &z[0], &inc); \
496  else gmm::clear(z); \
497  }
498 
499  // Y <- AX.
500  gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, sgemv_,
501  BLAS_S, col_major)
502  gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, dgemv_,
503  BLAS_D, col_major)
504  gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, cgemv_,
505  BLAS_C, col_major)
506  gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, zgemv_,
507  BLAS_Z, col_major)
508 
509  // Y <- transposed(A)X.
510  gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
511  BLAS_S, row_major)
512  gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
513  BLAS_D, row_major)
514  gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
515  BLAS_C, row_major)
516  gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
517  BLAS_Z, row_major)
518 
519  // Y <- transposed(const A)X.
520  gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
521  BLAS_S, row_major)
522  gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
523  BLAS_D, row_major)
524  gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
525  BLAS_C, row_major)
526  gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
527  BLAS_Z, row_major)
528 
529  // Y <- conjugated(A)X.
530  gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, sgemv_,
531  BLAS_S, row_major)
532  gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, dgemv_,
533  BLAS_D, row_major)
534  gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, cgemv_,
535  BLAS_C, row_major)
536  gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, zgemv_,
537  BLAS_Z, row_major)
538 
539  // Y <- A scaled(X).
540  gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, sgemv_,
541  BLAS_S, col_major)
542  gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, dgemv_,
543  BLAS_D, col_major)
544  gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, cgemv_,
545  BLAS_C, col_major)
546  gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, zgemv_,
547  BLAS_Z, col_major)
548 
549  // Y <- transposed(A) scaled(X).
550  gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
551  BLAS_S, row_major)
552  gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
553  BLAS_D, row_major)
554  gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
555  BLAS_C, row_major)
556  gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
557  BLAS_Z, row_major)
558 
559  // Y <- transposed(const A) scaled(X).
560  gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
561  BLAS_S, row_major)
562  gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
563  BLAS_D, row_major)
564  gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
565  BLAS_C, row_major)
566  gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
567  BLAS_Z, row_major)
568 
569  // Y <- conjugated(A) scaled(X).
570  gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, sgemv_,
571  BLAS_S, row_major)
572  gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, dgemv_,
573  BLAS_D, row_major)
574  gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, cgemv_,
575  BLAS_C, row_major)
576  gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, zgemv_,
577  BLAS_Z, row_major)
578 
579 
580  /* ********************************************************************* */
581  /* Rank one update. */
582  /* ********************************************************************* */
583 
584 # define ger_interface(blas_name, base_type) \
585  inline void rank_one_update(const dense_matrix<base_type > &A, \
586  const std::vector<base_type > &V, \
587  const std::vector<base_type > &W) { \
588  GMMLAPACK_TRACE("ger_interface"); \
589  long m(long(mat_nrows(A))), lda = m, n(long(mat_ncols(A))); \
590  long incx = 1, incy = 1; \
591  base_type alpha(1); \
592  if (m && n) \
593  blas_name(&m, &n, &alpha, &V[0], &incx, &W[0], &incy, &A(0,0), &lda);\
594  }
595 
596  ger_interface(sger_, BLAS_S)
597  ger_interface(dger_, BLAS_D)
598  ger_interface(cgerc_, BLAS_C)
599  ger_interface(zgerc_, BLAS_Z)
600 
601 # define ger_interface_sn(blas_name, base_type) \
602  inline void rank_one_update(const dense_matrix<base_type > &A, \
603  gemv_p2_s(base_type), \
604  const std::vector<base_type > &W) { \
605  GMMLAPACK_TRACE("ger_interface"); \
606  gemv_trans2_s(base_type); \
607  long m(long(mat_nrows(A))), lda = m, n(long(mat_ncols(A))); \
608  long incx = 1, incy = 1; \
609  if (m && n) \
610  blas_name(&m, &n, &alpha, &x[0], &incx, &W[0], &incy, &A(0,0), &lda);\
611  }
612 
613  ger_interface_sn(sger_, BLAS_S)
614  ger_interface_sn(dger_, BLAS_D)
615  ger_interface_sn(cgerc_, BLAS_C)
616  ger_interface_sn(zgerc_, BLAS_Z)
617 
618 # define ger_interface_ns(blas_name, base_type) \
619  inline void rank_one_update(const dense_matrix<base_type > &A, \
620  const std::vector<base_type > &V, \
621  gemv_p2_s(base_type)) { \
622  GMMLAPACK_TRACE("ger_interface"); \
623  gemv_trans2_s(base_type); \
624  long m(long(mat_nrows(A))), lda = m, n(long(mat_ncols(A))); \
625  long incx = 1, incy = 1; \
626  base_type al2 = gmm::conj(alpha); \
627  if (m && n) \
628  blas_name(&m, &n, &al2, &V[0], &incx, &x[0], &incy, &A(0,0), &lda); \
629  }
630 
631  ger_interface_ns(sger_, BLAS_S)
632  ger_interface_ns(dger_, BLAS_D)
633  ger_interface_ns(cgerc_, BLAS_C)
634  ger_interface_ns(zgerc_, BLAS_Z)
635 
636  /* ********************************************************************* */
637  /* dense matrix x dense matrix multiplication. */
638  /* ********************************************************************* */
639 
640 # define gemm_interface_nn(blas_name, base_type) \
641  inline void mult_spec(const dense_matrix<base_type > &A, \
642  const dense_matrix<base_type > &B, \
643  dense_matrix<base_type > &C, c_mult) { \
644  GMMLAPACK_TRACE("gemm_interface_nn"); \
645  const char t = 'N'; \
646  long m(long(mat_nrows(A))), lda = m, k(long(mat_ncols(A))); \
647  long n(long(mat_ncols(B))); \
648  long ldb = k, ldc = m; \
649  base_type alpha(1), beta(0); \
650  if (m && k && n) \
651  blas_name(&t, &t, &m, &n, &k, &alpha, \
652  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
653  else gmm::clear(C); \
654  }
655 
656  gemm_interface_nn(sgemm_, BLAS_S)
657  gemm_interface_nn(dgemm_, BLAS_D)
658  gemm_interface_nn(cgemm_, BLAS_C)
659  gemm_interface_nn(zgemm_, BLAS_Z)
660 
661  /* ********************************************************************* */
662  /* transposed(dense matrix) x dense matrix multiplication. */
663  /* ********************************************************************* */
664 
665 # define gemm_interface_tn(blas_name, base_type, is_const) \
666  inline void mult_spec( \
667  const transposed_col_ref<is_const<base_type > *> &A_,\
668  const dense_matrix<base_type > &B, \
669  dense_matrix<base_type > &C, rcmult) { \
670  GMMLAPACK_TRACE("gemm_interface_tn"); \
671  dense_matrix<base_type > &A \
672  = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
673  const char t = 'T', u = 'N'; \
674  long m(long(mat_ncols(A))), k(long(mat_nrows(A))), n(long(mat_ncols(B))); \
675  long lda = k, ldb = k, ldc = m; \
676  base_type alpha(1), beta(0); \
677  if (m && k && n) \
678  blas_name(&t, &u, &m, &n, &k, &alpha, \
679  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
680  else gmm::clear(C); \
681  }
682 
683  gemm_interface_tn(sgemm_, BLAS_S, dense_matrix)
684  gemm_interface_tn(dgemm_, BLAS_D, dense_matrix)
685  gemm_interface_tn(cgemm_, BLAS_C, dense_matrix)
686  gemm_interface_tn(zgemm_, BLAS_Z, dense_matrix)
687  gemm_interface_tn(sgemm_, BLAS_S, const dense_matrix)
688  gemm_interface_tn(dgemm_, BLAS_D, const dense_matrix)
689  gemm_interface_tn(cgemm_, BLAS_C, const dense_matrix)
690  gemm_interface_tn(zgemm_, BLAS_Z, const dense_matrix)
691 
692  /* ********************************************************************* */
693  /* dense matrix x transposed(dense matrix) multiplication. */
694  /* ********************************************************************* */
695 
696 # define gemm_interface_nt(blas_name, base_type, is_const) \
697  inline void mult_spec(const dense_matrix<base_type > &A, \
698  const transposed_col_ref<is_const<base_type > *> &B_, \
699  dense_matrix<base_type > &C, r_mult) { \
700  GMMLAPACK_TRACE("gemm_interface_nt"); \
701  dense_matrix<base_type > &B \
702  = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \
703  const char t = 'N', u = 'T'; \
704  long m(long(mat_nrows(A))), lda = m, k(long(mat_ncols(A))); \
705  long n(long(mat_nrows(B))); \
706  long ldb = n, ldc = m; \
707  base_type alpha(1), beta(0); \
708  if (m && k && n) \
709  blas_name(&t, &u, &m, &n, &k, &alpha, \
710  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
711  else gmm::clear(C); \
712  }
713 
714  gemm_interface_nt(sgemm_, BLAS_S, dense_matrix)
715  gemm_interface_nt(dgemm_, BLAS_D, dense_matrix)
716  gemm_interface_nt(cgemm_, BLAS_C, dense_matrix)
717  gemm_interface_nt(zgemm_, BLAS_Z, dense_matrix)
718  gemm_interface_nt(sgemm_, BLAS_S, const dense_matrix)
719  gemm_interface_nt(dgemm_, BLAS_D, const dense_matrix)
720  gemm_interface_nt(cgemm_, BLAS_C, const dense_matrix)
721  gemm_interface_nt(zgemm_, BLAS_Z, const dense_matrix)
722 
723  /* ********************************************************************* */
724  /* transposed(dense matrix) x transposed(dense matrix) multiplication. */
725  /* ********************************************************************* */
726 
727 # define gemm_interface_tt(blas_name, base_type, isA_const, isB_const) \
728  inline void mult_spec( \
729  const transposed_col_ref<isA_const <base_type > *> &A_, \
730  const transposed_col_ref<isB_const <base_type > *> &B_, \
731  dense_matrix<base_type > &C, r_mult) { \
732  GMMLAPACK_TRACE("gemm_interface_tt"); \
733  dense_matrix<base_type > &A \
734  = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
735  dense_matrix<base_type > &B \
736  = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \
737  const char t = 'T', u = 'T'; \
738  long m(long(mat_ncols(A))), k(long(mat_nrows(A))), n(long(mat_nrows(B))); \
739  long lda = k, ldb = n, ldc = m; \
740  base_type alpha(1), beta(0); \
741  if (m && k && n) \
742  blas_name(&t, &u, &m, &n, &k, &alpha, \
743  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
744  else gmm::clear(C); \
745  }
746 
747  gemm_interface_tt(sgemm_, BLAS_S, dense_matrix, dense_matrix)
748  gemm_interface_tt(dgemm_, BLAS_D, dense_matrix, dense_matrix)
749  gemm_interface_tt(cgemm_, BLAS_C, dense_matrix, dense_matrix)
750  gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix, dense_matrix)
751  gemm_interface_tt(sgemm_, BLAS_S, const dense_matrix, dense_matrix)
752  gemm_interface_tt(dgemm_, BLAS_D, const dense_matrix, dense_matrix)
753  gemm_interface_tt(cgemm_, BLAS_C, const dense_matrix, dense_matrix)
754  gemm_interface_tt(zgemm_, BLAS_Z, const dense_matrix, dense_matrix)
755  gemm_interface_tt(sgemm_, BLAS_S, dense_matrix, const dense_matrix)
756  gemm_interface_tt(dgemm_, BLAS_D, dense_matrix, const dense_matrix)
757  gemm_interface_tt(cgemm_, BLAS_C, dense_matrix, const dense_matrix)
758  gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix, const dense_matrix)
759  gemm_interface_tt(sgemm_, BLAS_S, const dense_matrix, const dense_matrix)
760  gemm_interface_tt(dgemm_, BLAS_D, const dense_matrix, const dense_matrix)
761  gemm_interface_tt(cgemm_, BLAS_C, const dense_matrix, const dense_matrix)
762  gemm_interface_tt(zgemm_, BLAS_Z, const dense_matrix, const dense_matrix)
763 
764 
765  /* ********************************************************************* */
766  /* conjugated(dense matrix) x dense matrix multiplication. */
767  /* ********************************************************************* */
768 
769 # define gemm_interface_cn(blas_name, base_type) \
770  inline void mult_spec( \
771  const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &A_,\
772  const dense_matrix<base_type > &B, \
773  dense_matrix<base_type > &C, rcmult) { \
774  GMMLAPACK_TRACE("gemm_interface_cn"); \
775  dense_matrix<base_type > &A \
776  = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
777  const char t = 'C', u = 'N'; \
778  long m(long(mat_ncols(A))), k(long(mat_nrows(A))), n(long(mat_ncols(B))); \
779  long lda = k, ldb = k, ldc = m; \
780  base_type alpha(1), beta(0); \
781  if (m && k && n) \
782  blas_name(&t, &u, &m, &n, &k, &alpha, \
783  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
784  else gmm::clear(C); \
785  }
786 
787  gemm_interface_cn(sgemm_, BLAS_S)
788  gemm_interface_cn(dgemm_, BLAS_D)
789  gemm_interface_cn(cgemm_, BLAS_C)
790  gemm_interface_cn(zgemm_, BLAS_Z)
791 
792  /* ********************************************************************* */
793  /* dense matrix x conjugated(dense matrix) multiplication. */
794  /* ********************************************************************* */
795 
796 # define gemm_interface_nc(blas_name, base_type) \
797  inline void mult_spec(const dense_matrix<base_type > &A, \
798  const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &B_,\
799  dense_matrix<base_type > &C, c_mult, row_major) { \
800  GMMLAPACK_TRACE("gemm_interface_nc"); \
801  dense_matrix<base_type > &B \
802  = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \
803  const char t = 'N', u = 'C'; \
804  long m(long(mat_nrows(A))), lda = m, k(long(mat_ncols(A))); \
805  long n(long(mat_nrows(B))), ldb = n, ldc = m; \
806  base_type alpha(1), beta(0); \
807  if (m && k && n) \
808  blas_name(&t, &u, &m, &n, &k, &alpha, \
809  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
810  else gmm::clear(C); \
811  }
812 
813  gemm_interface_nc(sgemm_, BLAS_S)
814  gemm_interface_nc(dgemm_, BLAS_D)
815  gemm_interface_nc(cgemm_, BLAS_C)
816  gemm_interface_nc(zgemm_, BLAS_Z)
817 
818  /* ********************************************************************* */
819  /* conjugated(dense matrix) x conjugated(dense matrix) multiplication. */
820  /* ********************************************************************* */
821 
822 # define gemm_interface_cc(blas_name, base_type) \
823  inline void mult_spec( \
824  const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &A_,\
825  const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &B_,\
826  dense_matrix<base_type > &C, r_mult) { \
827  GMMLAPACK_TRACE("gemm_interface_cc"); \
828  dense_matrix<base_type > &A \
829  = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
830  dense_matrix<base_type > &B \
831  = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \
832  const char t = 'C', u = 'C'; \
833  long m(long(mat_ncols(A))), k(long(mat_nrows(A))), lda = k; \
834  long n(long(mat_nrows(B))), ldb = n, ldc = m; \
835  base_type alpha(1), beta(0); \
836  if (m && k && n) \
837  blas_name(&t, &u, &m, &n, &k, &alpha, \
838  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
839  else gmm::clear(C); \
840  }
841 
842  gemm_interface_cc(sgemm_, BLAS_S)
843  gemm_interface_cc(dgemm_, BLAS_D)
844  gemm_interface_cc(cgemm_, BLAS_C)
845  gemm_interface_cc(zgemm_, BLAS_Z)
846 
847  /* ********************************************************************* */
848  /* Tri solve. */
849  /* ********************************************************************* */
850 
// Generates `inline void f_name(<matrix param>, std::vector<base_type> &x,
// size_type k, bool is_unit)`, which overwrites x with op(T)^{-1} x by
// calling the BLAS triangular solver `blas_name` (xTRSV) with N = k.
//   f_name    : name of the generated function (lower_tri_solve, ...).
//   loru      : statement defining the UPLO flag `l` (trsv_upper/trsv_lower).
//   param1    : helper macro producing the matrix parameter declaration.
//   trans1    : helper macro producing the statement(s) that define the
//               TRANS flag `t`.  Together with param1 it must leave a dense
//               matrix named `A` in scope (helpers not shown in this chunk).
//   blas_name : xTRSV routine to call.
//   base_type : scalar type of A and x.
// At run time, `is_unit` selects DIAG = 'U' (unit diagonal assumed by BLAS)
// instead of 'N'.  The leading dimension given to BLAS is the row count of A.
// The BLAS call is skipped when A has no rows or when k == 0.  Skipping the
// k == 0 case is a no-op for BLAS (it returns immediately for N = 0), and it
// avoids evaluating &x[0] on a possibly empty vector, which is undefined
// behaviour.
# define trsv_interface(f_name, loru, param1, trans1, blas_name, base_type)\
  inline void f_name(param1(base_type), std::vector<base_type > &x,        \
                     size_type k, bool is_unit) {                          \
    GMMLAPACK_TRACE("trsv_interface");                                     \
    loru; trans1(base_type); char d = is_unit ? 'U' : 'N';                 \
    long lda(long(mat_nrows(A))), inc(1), n = long(k);                     \
    /* n == 0 is a BLAS no-op; skipping it also keeps &x[0] valid. */      \
    if (lda && n)                                                          \
      blas_name(&l, &t, &d, &n, &A(0,0), &lda, &x[0], &inc);               \
  }

// UPLO flag definitions consumed as the `loru` argument of trsv_interface.
# define trsv_upper const char l = 'U'
# define trsv_lower const char l = 'L'
862 
863  // X <- LOWER(A)^{-1}X.
864  trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
865  strsv_, BLAS_S)
866  trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
867  dtrsv_, BLAS_D)
868  trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
869  ctrsv_, BLAS_C)
870  trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
871  ztrsv_, BLAS_Z)
872 
873  // X <- UPPER(A)^{-1}X.
874  trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
875  strsv_, BLAS_S)
876  trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
877  dtrsv_, BLAS_D)
878  trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
879  ctrsv_, BLAS_C)
880  trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
881  ztrsv_, BLAS_Z)
882 
883  // X <- LOWER(transposed(A))^{-1}X.
884  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
885  strsv_, BLAS_S)
886  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
887  dtrsv_, BLAS_D)
888  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
889  ctrsv_, BLAS_C)
890  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
891  ztrsv_, BLAS_Z)
892 
893  // X <- UPPER(transposed(A))^{-1}X.
894  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
895  strsv_, BLAS_S)
896  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
897  dtrsv_, BLAS_D)
898  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
899  ctrsv_, BLAS_C)
900  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
901  ztrsv_, BLAS_Z)
902 
903  // X <- LOWER(transposed(const A))^{-1}X.
904  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
905  strsv_, BLAS_S)
906  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
907  dtrsv_, BLAS_D)
908  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
909  ctrsv_, BLAS_C)
910  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
911  ztrsv_, BLAS_Z)
912 
913  // X <- UPPER(transposed(const A))^{-1}X.
914  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
915  strsv_, BLAS_S)
916  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
917  dtrsv_, BLAS_D)
918  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
919  ctrsv_, BLAS_C)
920  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
921  ztrsv_, BLAS_Z)
922 
923  // X <- LOWER(conjugated(A))^{-1}X.
924  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
925  strsv_, BLAS_S)
926  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
927  dtrsv_, BLAS_D)
928  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
929  ctrsv_, BLAS_C)
930  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
931  ztrsv_, BLAS_Z)
932 
933  // X <- UPPER(conjugated(A))^{-1}X.
934  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
935  strsv_, BLAS_S)
936  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
937  dtrsv_, BLAS_D)
938  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
939  ctrsv_, BLAS_C)
940  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
941  ztrsv_, BLAS_Z)
942 
943 #endif
944 }
945 
946 #endif // GMM_BLAS_INTERFACE_H
947 
948 #endif // GMM_USES_BLAS
gmm interface for STL vectors.
Basic linear algebra functions.
Declaration of some matrix types (gmm::dense_matrix, gmm::row_matrix, gmm::col_matrix, gmm::csc_matrix, etc.)