// Preamble of the BLAS interface and the prototype of daxpy_.
// The header body is compiled only when GETFEM_USES_BLAS, GMM_USES_BLAS,
// GMM_USES_LAPACK or GMM_USES_ATLAS is defined, and it is protected by the
// GMM_BLAS_INTERFACE_H include guard.
// BLAS_S, BLAS_D, BLAS_C and BLAS_Z name float, double, std::complex<float>
// and std::complex<double> respectively.
// GMMLAPACK_TRACE(f) appears to be defined with an empty body (tracing off).
// NOTE(review): the bare numbers embedded in the text (38, 39, 41, ...) look
// like line numbers carried over from a rendered listing, and the gaps between
// them suggest that lines were dropped -- confirm against the upstream header.
// daxpy_ receives every argument by pointer; only y is non-const.
38 #if defined(GETFEM_USES_BLAS) || defined(GMM_USES_BLAS) \ 39 || defined(GMM_USES_LAPACK) || defined(GMM_USES_ATLAS) 41 #ifndef GMM_BLAS_INTERFACE_H 42 #define GMM_BLAS_INTERFACE_H 52 #define GMMLAPACK_TRACE(f) 145 # define BLAS_S float 146 # define BLAS_D double 147 # define BLAS_C std::complex<float> 148 # define BLAS_Z std::complex<double> 154 void daxpy_(
const long *n,
const double *alpha,
const double *x,
155 const long *incx,
double *y,
const long *incy);
// Prototype of dgemm_ for double data. All arguments are passed by pointer:
// two character flags (tA, tB), the dimensions m, n and k, the scalars alpha
// and beta, the const arrays A and B with their leading dimensions ldA and
// ldB, and the non-const array C.
// NOTE(review): the declaration is cut off after `double *C,` (listing line
// 160 is absent) -- the remaining parameter(s) and the closing `);` must be
// restored from the upstream header.
156 void dgemm_(
const char *tA,
const char *tB,
const long *m,
157 const long *n,
const long *k,
const double *alpha,
158 const double *A,
const long *ldA,
const double *B,
159 const long *ldB,
const double *beta,
double *C,
// Remaining BLAS entry points, declared with an unchecked `(...)` parameter
// list, so the compiler does not verify the arguments the wrappers pass.
// Families declared here: gemm, gemv, trsv, axpy, dot/dotu/dotc, nrm2 and
// ger/gerc, each for the s/d/c/z prefixes.
// The dot and nrm2 families return a value (BLAS_S, BLAS_D, BLAS_C or BLAS_Z;
// scnrm2_ returns BLAS_S and dznrm2_ returns BLAS_D); all others return void.
161 void sgemm_(...);
void cgemm_(...);
void zgemm_(...);
162 void sgemv_(...);
void dgemv_(...);
void cgemv_(...);
void zgemv_(...);
163 void strsv_(...);
void dtrsv_(...);
void ctrsv_(...);
void ztrsv_(...);
164 void saxpy_(...);
void caxpy_(...);
void zaxpy_(...);
165 BLAS_S sdot_ (...); BLAS_D ddot_ (...);
166 BLAS_C cdotu_(...); BLAS_Z zdotu_(...);
167 BLAS_C cdotc_(...); BLAS_Z zdotc_(...);
168 BLAS_S snrm2_(...); BLAS_D dnrm2_(...);
169 BLAS_S scnrm2_(...); BLAS_D dznrm2_(...);
170 void sger_(...);
void dger_(...);
void cgerc_(...);
void zgerc_(...);
// nrm2_interface generates an inline vect_norm2 overload whose parameter is
// produced by param1 (nrm2_p1: a const std::vector<base_type> reference x).
// The generated body traces through GMMLAPACK_TRACE, sets inc = 1 and
// n = vect_size(x), expands trans1 (nrm2_trans1 is empty) and returns
// blas_name(&n, &x[0], &inc). The declared return type is
// number_traits<base_type>::magnitude_type.
// Instantiated for snrm2_/BLAS_S, dnrm2_/BLAS_D, scnrm2_/BLAS_C and
// dznrm2_/BLAS_Z.
// NOTE(review): &x[0] is formed without an emptiness check -- confirm that
// callers never pass an empty vector.
// NOTE(review): listing lines 185-186 (presumably the closing brace of the
// macro body) are missing from this text.
179 # define nrm2_interface(param1, trans1, blas_name, base_type) \ 180 inline number_traits<base_type >::magnitude_type \ 181 vect_norm2(param1(base_type)) { \ 182 GMMLAPACK_TRACE("nrm2_interface"); \ 183 long inc(1), n(long(vect_size(x))); trans1(base_type); \ 184 return blas_name(&n, &x[0], &inc); \ 187 # define nrm2_p1(base_type) const std::vector<base_type > &x 188 # define nrm2_trans1(base_type) 190 nrm2_interface(nrm2_p1, nrm2_trans1, snrm2_ , BLAS_S)
191 nrm2_interface(nrm2_p1, nrm2_trans1, dnrm2_ , BLAS_D)
192 nrm2_interface(nrm2_p1, nrm2_trans1, scnrm2_, BLAS_C)
193 nrm2_interface(nrm2_p1, nrm2_trans1, dznrm2_, BLAS_Z)
// dot_interface generates inline vect_sp overloads returning base_type.
// param1/param2 supply the two parameter declarations, trans1/trans2 are
// expanded as statements at the top of the body, and the returned expression
// is `mult1 mult2 blas_name(&n, &x[0], &inc, &y[0], &inc)` with inc = 1 and
// n = vect_size(y).
// Plain operands (dot_p1 / dot_p2) are const std::vector references named x
// and y and use a cast such as (BLAS_S) as their multiplier. Scaled operands
// (dot_p1_s / dot_p2_s) are scaled_vector_const_ref arguments named x_ / y_;
// dot_trans1_s / dot_trans2_s recover the underlying vector through
// linalg_origin and const_cast, and the multipliers used are `a*` / `b*`.
// NOTE(review): the definitions of a and b (listing lines 214 and 223) and the
// trailing arguments of the invocations (BLAS routine and/or element type) are
// not visible here -- restore them from the upstream header.
// NOTE(review): n is taken from y only; x is assumed to have at least the same
// length -- confirm against callers.
199 # define dot_interface(param1, trans1, mult1, param2, trans2, mult2, \ 200 blas_name, base_type) \ 201 inline base_type vect_sp(param1(base_type), param2(base_type)) { \ 202 GMMLAPACK_TRACE("dot_interface"); \ 203 trans1(base_type); trans2(base_type); long inc(1), n(long(vect_size(y)));\ 204 return mult1 mult2 blas_name(&n, &x[0], &inc, &y[0], &inc); \ 207 # define dot_p1(base_type) const std::vector<base_type > &x 208 # define dot_trans1(base_type) 209 # define dot_p1_s(base_type) \ 210 const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_ 211 # define dot_trans1_s(base_type) \ 212 std::vector<base_type > &x = \ 213 const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \ 216 # define dot_p2(base_type) const std::vector<base_type > &y 217 # define dot_trans2(base_type) 218 # define dot_p2_s(base_type) \ 219 const scaled_vector_const_ref<std::vector<base_type >, base_type > &y_ 220 # define dot_trans2_s(base_type) \ 221 std::vector<base_type > &y = \ 222 const_cast<std::vector<base_type > &>(*(linalg_origin(y_))); \ 225 dot_interface(dot_p1, dot_trans1, (BLAS_S), dot_p2, dot_trans2, (BLAS_S),
227 dot_interface(dot_p1, dot_trans1, (BLAS_D), dot_p2, dot_trans2, (BLAS_D),
229 dot_interface(dot_p1, dot_trans1, (BLAS_C), dot_p2, dot_trans2, (BLAS_C),
231 dot_interface(dot_p1, dot_trans1, (BLAS_Z), dot_p2, dot_trans2, (BLAS_Z),
234 dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_S),
236 dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_D),
238 dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_C),
240 dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_Z),
243 dot_interface(dot_p1, dot_trans1, (BLAS_S), dot_p2_s, dot_trans2_s, b*,
245 dot_interface(dot_p1, dot_trans1, (BLAS_D), dot_p2_s, dot_trans2_s, b*,
247 dot_interface(dot_p1, dot_trans1, (BLAS_C), dot_p2_s, dot_trans2_s, b*,
249 dot_interface(dot_p1, dot_trans1, (BLAS_Z), dot_p2_s, dot_trans2_s, b*,
252 dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,sdot_ ,
254 dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,ddot_ ,
256 dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,cdotu_,
258 dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,zdotu_,
// dotc_interface mirrors dot_interface but generates inline vect_hp overloads.
// The body expands trans1/trans2, sets inc = 1 and n = vect_size(y), and
// returns `mult1 mult2 blas_name(&n, &x[0], &inc, &y[0], &inc)`.
// Plain operands (dotc_p1 / dotc_p2) are const std::vector references; scaled
// operands (dotc_p1_s / dotc_p2_s) are scaled_vector_const_ref arguments whose
// underlying vector is recovered with linalg_origin and const_cast.
// dotc_trans2_s additionally defines b as gmm::conj(y_.r), i.e. the factor of
// the second operand is conjugated before it multiplies the result.
// The fully visible invocations pair BLAS_S/BLAS_D with sdot_/ddot_ and
// BLAS_C/BLAS_Z with cdotc_/zdotc_.
// NOTE(review): the definition of a (listing lines 281-282) and the trailing
// arguments of the invocations that end in a comma are not visible here --
// restore them from the upstream header.
// NOTE(review): n is taken from y only; x is assumed to have at least the same
// length -- confirm against callers.
266 # define dotc_interface(param1, trans1, mult1, param2, trans2, mult2, \ 267 blas_name, base_type) \ 268 inline base_type vect_hp(param1(base_type), param2(base_type)) { \ 269 GMMLAPACK_TRACE("dotc_interface"); \ 270 trans1(base_type); trans2(base_type); long inc(1), n(long(vect_size(y)));\ 271 return mult1 mult2 blas_name(&n, &x[0], &inc, &y[0], &inc); \ 274 # define dotc_p1(base_type) const std::vector<base_type > &x 275 # define dotc_trans1(base_type) 276 # define dotc_p1_s(base_type) \ 277 const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_ 278 # define dotc_trans1_s(base_type) \ 279 std::vector<base_type > &x = \ 280 const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \ 283 # define dotc_p2(base_type) const std::vector<base_type > &y 284 # define dotc_trans2(base_type) 285 # define dotc_p2_s(base_type) \ 286 const scaled_vector_const_ref<std::vector<base_type >, base_type > &y_ 287 # define dotc_trans2_s(base_type) \ 288 std::vector<base_type > &y = \ 289 const_cast<std::vector<base_type > &>(*(linalg_origin(y_))); \ 290 base_type b(gmm::conj(y_.r)) 292 dotc_interface(dotc_p1, dotc_trans1, (BLAS_S), dotc_p2, dotc_trans2,
293 (BLAS_S),sdot_ ,BLAS_S)
294 dotc_interface(dotc_p1, dotc_trans1, (BLAS_D), dotc_p2, dotc_trans2,
295 (BLAS_D),ddot_ ,BLAS_D)
296 dotc_interface(dotc_p1, dotc_trans1, (BLAS_C), dotc_p2, dotc_trans2,
297 (BLAS_C),cdotc_,BLAS_C)
298 dotc_interface(dotc_p1, dotc_trans1, (BLAS_Z), dotc_p2, dotc_trans2,
299 (BLAS_Z),zdotc_,BLAS_Z)
301 dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
302 (BLAS_S),sdot_, BLAS_S)
303 dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
304 (BLAS_D),ddot_ , BLAS_D)
305 dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
306 (BLAS_C),cdotc_, BLAS_C)
307 dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
308 (BLAS_Z),zdotc_, BLAS_Z)
310 dotc_interface(dotc_p1, dotc_trans1, (BLAS_S), dotc_p2_s, dotc_trans2_s,
312 dotc_interface(dotc_p1, dotc_trans1, (BLAS_D), dotc_p2_s, dotc_trans2_s,
314 dotc_interface(dotc_p1, dotc_trans1, (BLAS_C), dotc_p2_s, dotc_trans2_s,
316 dotc_interface(dotc_p1, dotc_trans1, (BLAS_Z), dotc_p2_s, dotc_trans2_s,
319 dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,sdot_ ,
321 dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,ddot_ ,
323 dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,cdotc_,
325 dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,zdotc_,
// axpy_interface generates inline add(param1, std::vector<base_type> &y)
// overloads. The body sets inc = 1 and n = vect_size(y), expands trans1,
// returns immediately when n == 0, and otherwise calls
// blas_name(&n, &a, &x[0], &inc, &y[0], &inc).
// axpy_p1 / axpy_trans1: x is a const std::vector reference and a is 1.
// axpy_p1_s / axpy_trans1_s: the argument is a scaled_vector_const_ref named
// x_, and x is recovered from it with linalg_origin and const_cast.
// Instantiated for saxpy_, daxpy_, caxpy_ and zaxpy_ in both forms.
// NOTE(review): the definition of a for the scaled form (listing lines
// 347-348) and the closing brace of the macro body are not visible here.
// NOTE(review): n comes from y only; x is assumed to be at least as long.
332 # define axpy_interface(param1, trans1, blas_name, base_type) \ 333 inline void add(param1(base_type), std::vector<base_type > &y) { \ 334 GMMLAPACK_TRACE("axpy_interface"); \ 335 long inc(1), n(long(vect_size(y))); trans1(base_type); \ 336 if (n == 0) return; \ 337 blas_name(&n, &a, &x[0], &inc, &y[0], &inc); \ 340 # define axpy_p1(base_type) const std::vector<base_type > &x 341 # define axpy_trans1(base_type) base_type a(1) 342 # define axpy_p1_s(base_type) \ 343 const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_ 344 # define axpy_trans1_s(base_type) \ 345 std::vector<base_type > &x = \ 346 const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \ 349 axpy_interface(axpy_p1, axpy_trans1, saxpy_, BLAS_S)
350 axpy_interface(axpy_p1, axpy_trans1, daxpy_, BLAS_D)
351 axpy_interface(axpy_p1, axpy_trans1, caxpy_, BLAS_C)
352 axpy_interface(axpy_p1, axpy_trans1, zaxpy_, BLAS_Z)
354 axpy_interface(axpy_p1_s, axpy_trans1_s, saxpy_, BLAS_S)
355 axpy_interface(axpy_p1_s, axpy_trans1_s, daxpy_, BLAS_D)
356 axpy_interface(axpy_p1_s, axpy_trans1_s, caxpy_, BLAS_C)
357 axpy_interface(axpy_p1_s, axpy_trans1_s, zaxpy_, BLAS_Z)
// gemv_interface generates inline mult_add_spec(matrix, vector, z, orien)
// overloads. The body expands trans1/trans2, sets beta = 1, m = mat_nrows(A),
// lda = m, n = mat_ncols(A) and inc = 1. When both m and n are non-zero it
// calls blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc, &beta,
// &z[0], &inc); otherwise it calls gmm::clear(z).
// Matrix argument helpers:
//   gem_p1_n / gem_trans1_n : const dense_matrix reference A, t = 'N'.
//   gem_p1_t / gem_trans1_t : transposed_col_ref of a dense_matrix pointer
//                             named A_; A is recovered via linalg_origin.
//   gem_p1_tc               : same as gem_p1_t for a const dense_matrix.
//   gem_p1_c / gem_trans1_c : conjugated_col_matrix_const_ref named A_; A is
//                             recovered via linalg_origin.
// Vector argument helpers:
//   gemv_p2_n / gemv_trans2_n : const std::vector reference x, alpha = 1.
//   gemv_p2_s / gemv_trans2_s : scaled_vector_const_ref x_, x recovered via
//                               linalg_origin, alpha = x_.r.
// NOTE(review): clearing z when m or n is zero looks suspicious for a routine
// that passes beta = 1 -- confirm the intended behaviour.
// NOTE(review): the end of the macro parameter list (listing line 365), the
// definition of t for the transposed/conjugated helpers (lines 383, 390-391)
// and the trailing arguments of every invocation are not visible here.
364 # define gemv_interface(param1, trans1, param2, trans2, blas_name, \ 366 inline void mult_add_spec(param1(base_type), param2(base_type), \ 367 std::vector<base_type > &z, orien) { \ 368 GMMLAPACK_TRACE("gemv_interface"); \ 369 trans1(base_type); trans2(base_type); base_type beta(1); \ 370 long m(long(mat_nrows(A))), lda(m), n(long(mat_ncols(A))), inc(1); \ 371 if (m && n) blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc, \ 372 &beta, &z[0], &inc); \ 373 else gmm::clear(z); \ 377 # define gem_p1_n(base_type) const dense_matrix<base_type > &A 378 # define gem_trans1_n(base_type) const char t = 'N' 379 # define gem_p1_t(base_type) \ 380 const transposed_col_ref<dense_matrix<base_type > *> &A_ 381 # define gem_trans1_t(base_type) dense_matrix<base_type > &A = \ 382 const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \ 384 # define gem_p1_tc(base_type) \ 385 const transposed_col_ref<const dense_matrix<base_type > *> &A_ 386 # define gem_p1_c(base_type) \ 387 const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &A_ 388 # define gem_trans1_c(base_type) dense_matrix<base_type > &A = \ 389 const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \ 393 # define gemv_p2_n(base_type) const std::vector<base_type > &x 394 # define gemv_trans2_n(base_type) base_type alpha(1) 395 # define gemv_p2_s(base_type) \ 396 const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_ 397 # define gemv_trans2_s(base_type) std::vector<base_type > &x = \ 398 const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \ 399 base_type alpha(x_.r) 402 gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, sgemv_,
404 gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, dgemv_,
406 gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, cgemv_,
408 gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, zgemv_,
412 gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
414 gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
416 gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
418 gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
422 gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
424 gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
426 gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
428 gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
432 gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, sgemv_,
434 gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, dgemv_,
436 gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, cgemv_,
438 gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, zgemv_,
442 gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, sgemv_,
444 gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, dgemv_,
446 gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, cgemv_,
448 gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, zgemv_,
452 gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
454 gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
456 gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
458 gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
462 gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
464 gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
466 gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
468 gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
472 gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, sgemv_,
474 gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, dgemv_,
476 gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, cgemv_,
478 gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, zgemv_,
// gemv_interface2 generates inline mult_spec(matrix, vector, z, orien)
// overloads. It is built from the same gem_p1_* / gem_trans1_* and
// gemv_p2_* / gemv_trans2_* helper macros as gemv_interface, but passes
// beta = 0 to the BLAS routine. The visible call is
// blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc, &beta, ...)
// with m = mat_nrows(A), lda = m, n = mat_ncols(A) and inc = 1; the
// alternative branch calls gmm::clear(z).
// NOTE(review): the end of the macro parameter list (listing line 487), the
// `if` guard (line 493), the last arguments of the BLAS call (line 495) and
// the trailing arguments of every invocation are not visible here -- restore
// them from the upstream header.
486 # define gemv_interface2(param1, trans1, param2, trans2, blas_name, \ 488 inline void mult_spec(param1(base_type), param2(base_type), \ 489 std::vector<base_type > &z, orien) { \ 490 GMMLAPACK_TRACE("gemv_interface2"); \ 491 trans1(base_type); trans2(base_type); base_type beta(0); \ 492 long m(long(mat_nrows(A))), lda(m), n(long(mat_ncols(A))), inc(1); \ 494 blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc, &beta, \ 496 else gmm::clear(z); \ 500 gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, sgemv_,
502 gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, dgemv_,
504 gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, cgemv_,
506 gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, zgemv_,
510 gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
512 gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
514 gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
516 gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
520 gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
522 gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
524 gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
526 gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
530 gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, sgemv_,
532 gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, dgemv_,
534 gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, cgemv_,
536 gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, zgemv_,
540 gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, sgemv_,
542 gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, dgemv_,
544 gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, cgemv_,
546 gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, zgemv_,
550 gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
552 gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
554 gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
556 gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
560 gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
562 gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
564 gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
566 gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
570 gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, sgemv_,
572 gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, dgemv_,
574 gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, cgemv_,
576 gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, zgemv_,
// ger_interface generates inline rank_one_update(A, V, W) overloads for a
// dense_matrix and two plain std::vector operands. The body sets
// m = mat_nrows(A), lda = m, n = mat_ncols(A), incx = incy = 1 and alpha = 1,
// then calls blas_name(&m, &n, &alpha, &V[0], &incx, &W[0], &incy, &A(0,0), &lda).
// Instantiated with sger_, dger_, cgerc_ and zgerc_.
// NOTE(review): A is taken by const reference although the address of its
// first element is handed to the BLAS routine, which presumably writes to it
// -- confirm that this is intentional.
// NOTE(review): listing line 592 (possibly a size guard) and the closing brace
// of the macro body are not visible here.
584 # define ger_interface(blas_name, base_type) \ 585 inline void rank_one_update(const dense_matrix<base_type > &A, \ 586 const std::vector<base_type > &V, \ 587 const std::vector<base_type > &W) { \ 588 GMMLAPACK_TRACE("ger_interface"); \ 589 long m(long(mat_nrows(A))), lda = m, n(long(mat_ncols(A))); \ 590 long incx = 1, incy = 1; \ 591 base_type alpha(1); \ 593 blas_name(&m, &n, &alpha, &V[0], &incx, &W[0], &incy, &A(0,0), &lda);\ 596 ger_interface(sger_, BLAS_S)
597 ger_interface(dger_, BLAS_D)
598 ger_interface(cgerc_, BLAS_C)
599 ger_interface(zgerc_, BLAS_Z)
// ger_interface_sn generates rank_one_update overloads whose first vector is
// a scaled vector reference: the parameter is declared through gemv_p2_s
// (named x_), and gemv_trans2_s recovers x and sets alpha = x_.r.
// The body then calls
// blas_name(&m, &n, &alpha, &x[0], &incx, &W[0], &incy, &A(0,0), &lda)
// with m = mat_nrows(A), lda = m, n = mat_ncols(A) and unit increments.
// The trace label is "ger_interface", the same string the plain form uses.
// NOTE(review): listing line 609 and the closing brace of the macro body are
// not visible here.
601 # define ger_interface_sn(blas_name, base_type) \ 602 inline void rank_one_update(const dense_matrix<base_type > &A, \ 603 gemv_p2_s(base_type), \ 604 const std::vector<base_type > &W) { \ 605 GMMLAPACK_TRACE("ger_interface"); \ 606 gemv_trans2_s(base_type); \ 607 long m(long(mat_nrows(A))), lda = m, n(long(mat_ncols(A))); \ 608 long incx = 1, incy = 1; \ 610 blas_name(&m, &n, &alpha, &x[0], &incx, &W[0], &incy, &A(0,0), &lda);\ 613 ger_interface_sn(sger_, BLAS_S)
614 ger_interface_sn(dger_, BLAS_D)
615 ger_interface_sn(cgerc_, BLAS_C)
616 ger_interface_sn(zgerc_, BLAS_Z)
// ger_interface_ns generates rank_one_update overloads whose second vector is
// a scaled vector reference (declared through gemv_p2_s, named x_).
// gemv_trans2_s recovers x and sets alpha = x_.r; the factor actually passed
// to BLAS is al2 = gmm::conj(alpha). The call is
// blas_name(&m, &n, &al2, &V[0], &incx, &x[0], &incy, &A(0,0), &lda)
// with m = mat_nrows(A), lda = m, n = mat_ncols(A) and unit increments.
// The trace label is "ger_interface", the same string the plain form uses.
// NOTE(review): listing line 627 and the closing brace of the macro body are
// not visible here.
618 # define ger_interface_ns(blas_name, base_type) \ 619 inline void rank_one_update(const dense_matrix<base_type > &A, \ 620 const std::vector<base_type > &V, \ 621 gemv_p2_s(base_type)) { \ 622 GMMLAPACK_TRACE("ger_interface"); \ 623 gemv_trans2_s(base_type); \ 624 long m(long(mat_nrows(A))), lda = m, n(long(mat_ncols(A))); \ 625 long incx = 1, incy = 1; \ 626 base_type al2 = gmm::conj(alpha); \ 628 blas_name(&m, &n, &al2, &V[0], &incx, &x[0], &incy, &A(0,0), &lda); \ 631 ger_interface_ns(sger_, BLAS_S)
632 ger_interface_ns(dger_, BLAS_D)
633 ger_interface_ns(cgerc_, BLAS_C)
634 ger_interface_ns(zgerc_, BLAS_Z)
// gemm_interface_nn generates mult_spec(A, B, C, c_mult) for three plain
// dense_matrix operands. Both flags passed to BLAS are 'N'.
// Dimensions: m = mat_nrows(A), k = mat_ncols(A), n = mat_ncols(B);
// leading dimensions: lda = m, ldb = k, ldc = m; alpha = 1, beta = 0.
// The alternative branch calls gmm::clear(C).
// NOTE(review): the `if` guard that selects between the BLAS call and the
// clear (listing line 650) and the closing brace are not visible here.
640 # define gemm_interface_nn(blas_name, base_type) \ 641 inline void mult_spec(const dense_matrix<base_type > &A, \ 642 const dense_matrix<base_type > &B, \ 643 dense_matrix<base_type > &C, c_mult) { \ 644 GMMLAPACK_TRACE("gemm_interface_nn"); \ 645 const char t = 'N'; \ 646 long m(long(mat_nrows(A))), lda = m, k(long(mat_ncols(A))); \ 647 long n(long(mat_ncols(B))); \ 648 long ldb = k, ldc = m; \ 649 base_type alpha(1), beta(0); \ 651 blas_name(&t, &t, &m, &n, &k, &alpha, \ 652 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \ 653 else gmm::clear(C); \ 656 gemm_interface_nn(sgemm_, BLAS_S)
657 gemm_interface_nn(dgemm_, BLAS_D)
658 gemm_interface_nn(cgemm_, BLAS_C)
659 gemm_interface_nn(zgemm_, BLAS_Z)
// gemm_interface_tn generates mult_spec(A_, B, C, rcmult) where A_ is a
// transposed_col_ref; is_const selects dense_matrix or const dense_matrix as
// the referenced type. The underlying matrix A is recovered with
// linalg_origin and const_cast, and BLAS is called with flags 'T' and 'N'.
// Dimensions are read from the underlying A: m = mat_ncols(A),
// k = mat_nrows(A), n = mat_ncols(B); lda = k, ldb = k, ldc = m;
// alpha = 1, beta = 0. The alternative branch calls gmm::clear(C).
// NOTE(review): the `if` guard (listing line 677) and the closing brace are
// not visible here.
665 # define gemm_interface_tn(blas_name, base_type, is_const) \ 666 inline void mult_spec( \ 667 const transposed_col_ref<is_const<base_type > *> &A_,\ 668 const dense_matrix<base_type > &B, \ 669 dense_matrix<base_type > &C, rcmult) { \ 670 GMMLAPACK_TRACE("gemm_interface_tn"); \ 671 dense_matrix<base_type > &A \ 672 = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \ 673 const char t = 'T', u = 'N'; \ 674 long m(long(mat_ncols(A))), k(long(mat_nrows(A))), n(long(mat_ncols(B))); \ 675 long lda = k, ldb = k, ldc = m; \ 676 base_type alpha(1), beta(0); \ 678 blas_name(&t, &u, &m, &n, &k, &alpha, \ 679 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \ 680 else gmm::clear(C); \ 683 gemm_interface_tn(sgemm_, BLAS_S, dense_matrix)
684 gemm_interface_tn(dgemm_, BLAS_D, dense_matrix)
685 gemm_interface_tn(cgemm_, BLAS_C, dense_matrix)
686 gemm_interface_tn(zgemm_, BLAS_Z, dense_matrix)
687 gemm_interface_tn(sgemm_, BLAS_S, const dense_matrix)
688 gemm_interface_tn(dgemm_, BLAS_D, const dense_matrix)
689 gemm_interface_tn(cgemm_, BLAS_C, const dense_matrix)
690 gemm_interface_tn(zgemm_, BLAS_Z, const dense_matrix)
// gemm_interface_nt generates mult_spec(A, B_, C, r_mult) where B_ is a
// transposed_col_ref; is_const selects dense_matrix or const dense_matrix as
// the referenced type. The underlying matrix B is recovered with
// linalg_origin and const_cast, and BLAS is called with flags 'N' and 'T'.
// Dimensions: m = mat_nrows(A), k = mat_ncols(A), n = mat_nrows(B) of the
// underlying B; lda = m, ldb = n, ldc = m; alpha = 1, beta = 0.
// The alternative branch calls gmm::clear(C).
// NOTE(review): the `if` guard (listing line 708) and the closing brace are
// not visible here.
696 # define gemm_interface_nt(blas_name, base_type, is_const) \ 697 inline void mult_spec(const dense_matrix<base_type > &A, \ 698 const transposed_col_ref<is_const<base_type > *> &B_, \ 699 dense_matrix<base_type > &C, r_mult) { \ 700 GMMLAPACK_TRACE("gemm_interface_nt"); \ 701 dense_matrix<base_type > &B \ 702 = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \ 703 const char t = 'N', u = 'T'; \ 704 long m(long(mat_nrows(A))), lda = m, k(long(mat_ncols(A))); \ 705 long n(long(mat_nrows(B))); \ 706 long ldb = n, ldc = m; \ 707 base_type alpha(1), beta(0); \ 709 blas_name(&t, &u, &m, &n, &k, &alpha, \ 710 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \ 711 else gmm::clear(C); \ 714 gemm_interface_nt(sgemm_, BLAS_S, dense_matrix)
715 gemm_interface_nt(dgemm_, BLAS_D, dense_matrix)
716 gemm_interface_nt(cgemm_, BLAS_C, dense_matrix)
717 gemm_interface_nt(zgemm_, BLAS_Z, dense_matrix)
718 gemm_interface_nt(sgemm_, BLAS_S, const dense_matrix)
719 gemm_interface_nt(dgemm_, BLAS_D, const dense_matrix)
720 gemm_interface_nt(cgemm_, BLAS_C, const dense_matrix)
721 gemm_interface_nt(zgemm_, BLAS_Z, const dense_matrix)
// gemm_interface_tt generates mult_spec(A_, B_, C, r_mult) where both inputs
// are transposed_col_ref arguments; isA_const / isB_const select dense_matrix
// or const dense_matrix independently, and the sixteen invocations cover all
// four const combinations for each of the four element types.
// Both underlying matrices are recovered with linalg_origin and const_cast,
// and BLAS is called with flags 'T' and 'T'.
// Dimensions from the underlying matrices: m = mat_ncols(A), k = mat_nrows(A),
// n = mat_nrows(B); lda = k, ldb = n, ldc = m; alpha = 1, beta = 0.
// The alternative branch calls gmm::clear(C).
// NOTE(review): the `if` guard (listing line 741) and the closing brace are
// not visible here.
727 # define gemm_interface_tt(blas_name, base_type, isA_const, isB_const) \ 728 inline void mult_spec( \ 729 const transposed_col_ref<isA_const <base_type > *> &A_, \ 730 const transposed_col_ref<isB_const <base_type > *> &B_, \ 731 dense_matrix<base_type > &C, r_mult) { \ 732 GMMLAPACK_TRACE("gemm_interface_tt"); \ 733 dense_matrix<base_type > &A \ 734 = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \ 735 dense_matrix<base_type > &B \ 736 = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \ 737 const char t = 'T', u = 'T'; \ 738 long m(long(mat_ncols(A))), k(long(mat_nrows(A))), n(long(mat_nrows(B))); \ 739 long lda = k, ldb = n, ldc = m; \ 740 base_type alpha(1), beta(0); \ 742 blas_name(&t, &u, &m, &n, &k, &alpha, \ 743 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \ 744 else gmm::clear(C); \ 747 gemm_interface_tt(sgemm_, BLAS_S, dense_matrix, dense_matrix)
748 gemm_interface_tt(dgemm_, BLAS_D, dense_matrix, dense_matrix)
749 gemm_interface_tt(cgemm_, BLAS_C, dense_matrix, dense_matrix)
750 gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix, dense_matrix)
751 gemm_interface_tt(sgemm_, BLAS_S, const dense_matrix, dense_matrix)
752 gemm_interface_tt(dgemm_, BLAS_D, const dense_matrix, dense_matrix)
753 gemm_interface_tt(cgemm_, BLAS_C, const dense_matrix, dense_matrix)
754 gemm_interface_tt(zgemm_, BLAS_Z, const dense_matrix, dense_matrix)
755 gemm_interface_tt(sgemm_, BLAS_S, dense_matrix, const dense_matrix)
756 gemm_interface_tt(dgemm_, BLAS_D, dense_matrix, const dense_matrix)
757 gemm_interface_tt(cgemm_, BLAS_C, dense_matrix, const dense_matrix)
758 gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix, const dense_matrix)
759 gemm_interface_tt(sgemm_, BLAS_S, const dense_matrix, const dense_matrix)
760 gemm_interface_tt(dgemm_, BLAS_D, const dense_matrix, const dense_matrix)
761 gemm_interface_tt(cgemm_, BLAS_C, const dense_matrix, const dense_matrix)
762 gemm_interface_tt(zgemm_, BLAS_Z, const dense_matrix, const dense_matrix)
// gemm_interface_cn generates mult_spec(A_, B, C, rcmult) where A_ is a
// conjugated_col_matrix_const_ref of a dense_matrix. The underlying matrix A
// is recovered with linalg_origin and const_cast, and BLAS is called with
// flags 'C' and 'N'.
// Dimensions from the underlying A: m = mat_ncols(A), k = mat_nrows(A),
// n = mat_ncols(B); lda = k, ldb = k, ldc = m; alpha = 1, beta = 0.
// The alternative branch calls gmm::clear(C).
// NOTE(review): the `if` guard (listing line 781) and the closing brace are
// not visible here.
769 # define gemm_interface_cn(blas_name, base_type) \ 770 inline void mult_spec( \ 771 const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &A_,\ 772 const dense_matrix<base_type > &B, \ 773 dense_matrix<base_type > &C, rcmult) { \ 774 GMMLAPACK_TRACE("gemm_interface_cn"); \ 775 dense_matrix<base_type > &A \ 776 = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \ 777 const char t = 'C', u = 'N'; \ 778 long m(long(mat_ncols(A))), k(long(mat_nrows(A))), n(long(mat_ncols(B))); \ 779 long lda = k, ldb = k, ldc = m; \ 780 base_type alpha(1), beta(0); \ 782 blas_name(&t, &u, &m, &n, &k, &alpha, \ 783 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \ 784 else gmm::clear(C); \ 787 gemm_interface_cn(sgemm_, BLAS_S)
788 gemm_interface_cn(dgemm_, BLAS_D)
789 gemm_interface_cn(cgemm_, BLAS_C)
790 gemm_interface_cn(zgemm_, BLAS_Z)
// gemm_interface_nc generates mult_spec(A, B_, C, c_mult, row_major) where B_
// is a conjugated_col_matrix_const_ref of a dense_matrix. The underlying
// matrix B is recovered with linalg_origin and const_cast, and BLAS is called
// with flags 'N' and 'C'. Unlike the other gemm wrappers visible in this
// file, this overload takes two trailing tag parameters.
// Dimensions: m = mat_nrows(A), k = mat_ncols(A), n = mat_nrows(B) of the
// underlying B; lda = m, ldb = n, ldc = m; alpha = 1, beta = 0.
// The alternative branch calls gmm::clear(C).
// NOTE(review): the `if` guard (listing line 807) and the closing brace are
// not visible here.
796 # define gemm_interface_nc(blas_name, base_type) \ 797 inline void mult_spec(const dense_matrix<base_type > &A, \ 798 const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &B_,\ 799 dense_matrix<base_type > &C, c_mult, row_major) { \ 800 GMMLAPACK_TRACE("gemm_interface_nc"); \ 801 dense_matrix<base_type > &B \ 802 = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \ 803 const char t = 'N', u = 'C'; \ 804 long m(long(mat_nrows(A))), lda = m, k(long(mat_ncols(A))); \ 805 long n(long(mat_nrows(B))), ldb = n, ldc = m; \ 806 base_type alpha(1), beta(0); \ 808 blas_name(&t, &u, &m, &n, &k, &alpha, \ 809 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \ 810 else gmm::clear(C); \ 813 gemm_interface_nc(sgemm_, BLAS_S)
814 gemm_interface_nc(dgemm_, BLAS_D)
815 gemm_interface_nc(cgemm_, BLAS_C)
816 gemm_interface_nc(zgemm_, BLAS_Z)
// gemm_interface_cc generates mult_spec(A_, B_, C, r_mult) where both inputs
// are conjugated_col_matrix_const_ref arguments. Both underlying matrices are
// recovered with linalg_origin and const_cast, and BLAS is called with flags
// 'C' and 'C'.
// Dimensions from the underlying matrices: m = mat_ncols(A), k = mat_nrows(A),
// n = mat_nrows(B); lda = k, ldb = n, ldc = m; alpha = 1, beta = 0.
// The alternative branch calls gmm::clear(C).
// NOTE(review): the `if` guard (listing line 836) and the closing brace are
// not visible here.
822 # define gemm_interface_cc(blas_name, base_type) \ 823 inline void mult_spec( \ 824 const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &A_,\ 825 const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &B_,\ 826 dense_matrix<base_type > &C, r_mult) { \ 827 GMMLAPACK_TRACE("gemm_interface_cc"); \ 828 dense_matrix<base_type > &A \ 829 = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \ 830 dense_matrix<base_type > &B \ 831 = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \ 832 const char t = 'C', u = 'C'; \ 833 long m(long(mat_ncols(A))), k(long(mat_nrows(A))), lda = k; \ 834 long n(long(mat_nrows(B))), ldb = n, ldc = m; \ 835 base_type alpha(1), beta(0); \ 837 blas_name(&t, &u, &m, &n, &k, &alpha, \ 838 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \ 839 else gmm::clear(C); \ 842 gemm_interface_cc(sgemm_, BLAS_S)
843 gemm_interface_cc(dgemm_, BLAS_D)
844 gemm_interface_cc(cgemm_, BLAS_C)
845 gemm_interface_cc(zgemm_, BLAS_Z)
// trsv_interface generates inline f_name(matrix, x, k, is_unit) overloads,
// where f_name is lower_tri_solve or upper_tri_solve in the invocations below.
// The body expands loru (trsv_upper defines l = 'U', trsv_lower defines
// l = 'L'), expands trans1 (which supplies A and t), sets d to 'U' when
// is_unit is true and 'N' otherwise, lda = mat_nrows(A), inc = 1 and n = k.
// The BLAS routine is called only when lda is non-zero:
// blas_name(&l, &t, &d, &n, &A(0,0), &lda, &x[0], &inc).
// For plain dense matrices (gem_p1_n) lower_tri_solve is paired with
// trsv_lower and upper_tri_solve with trsv_upper. For the transposed
// (gem_p1_t, gem_p1_tc) and conjugated (gem_p1_c) arguments the pairing is
// swapped: lower_tri_solve uses trsv_upper and upper_tri_solve uses
// trsv_lower, since the flag is applied to the underlying matrix.
// NOTE(review): the trailing arguments of every invocation (BLAS routine and
// element type) and the closing brace of the macro body are not visible in
// this listing -- restore them from the upstream header.
851 # define trsv_interface(f_name, loru, param1, trans1, blas_name, base_type)\ 852 inline void f_name(param1(base_type), std::vector<base_type > &x, \ 853 size_type k, bool is_unit) { \ 854 GMMLAPACK_TRACE("trsv_interface"); \ 855 loru; trans1(base_type); char d = is_unit ? 'U' : 'N'; \ 856 long lda(long(mat_nrows(A))), inc(1), n = long(k); \ 857 if (lda) blas_name(&l, &t, &d, &n, &A(0,0), &lda, &x[0], &inc); \ 860 # define trsv_upper const char l = 'U' 861 # define trsv_lower const char l = 'L' 864 trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
866 trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
868 trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
870 trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
874 trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
876 trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
878 trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
880 trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
884 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
886 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
888 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
890 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
894 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
896 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
898 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
900 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
904 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
906 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
908 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
910 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
914 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
916 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
918 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
920 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
924 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
926 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
928 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
930 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
934 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
936 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
938 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
940 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
946 #endif // GMM_BLAS_INTERFACE_H 948 #endif // GMM_USES_BLAS gmm interface for STL vectors.
Basic linear algebra functions.
Declaration of some matrix types (gmm::dense_matrix, gmm::row_matrix, gmm::col_matrix, gmm::csc_matrix, etc.)