From e2e1bb22d217ff841963de864fdebe5760845e31 Mon Sep 17 00:00:00 2001 From: Austin Harris Date: Mon, 13 Jan 2020 11:33:30 -0500 Subject: [PATCH] Update cublas interface --- source/cublasf.f90 | 114 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 108 insertions(+), 6 deletions(-) diff --git a/source/cublasf.f90 b/source/cublasf.f90 index fbc02d11..52980745 100644 --- a/source/cublasf.f90 +++ b/source/cublasf.f90 @@ -207,6 +207,48 @@ end function cublasGetMatrix type(c_ptr), value :: stream end function cublasGetMatrixAsync + integer(c_int) function & + & cublasDnrm2_v2(handle, n, dx, incx, xnorm) & + & bind(c, name="cublasDnrm2_v2") + use, intrinsic :: iso_c_binding + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: dx + integer(c_int), value :: incx + real(c_double) :: xnorm + end function cublasDnrm2_v2 + + integer(c_int) function & + & cublasDaxpy_v2(handle, n, alpha, dx, incx, dy, incy) & + & bind(c, name="cublasDaxpy_v2") + use, intrinsic :: iso_c_binding + type(c_ptr), value :: handle + integer(c_int), value :: n + real(c_double), value :: alpha + type(c_ptr), value :: dx + integer(c_int), value :: incx + type(c_ptr), value :: dy + integer(c_int), value :: incy + end function cublasDaxpy_v2 + + integer(c_int) function & + & cublasDgemv_v2(handle, trans, m, n, alpha, dA, ldda, dx, incx, beta, dy, incy) & + & bind(c, name="cublasDgemv_v2") + use, intrinsic :: iso_c_binding + type(c_ptr), value :: handle + integer(c_int), value :: trans + integer(c_int), value :: m + integer(c_int), value :: n + real(c_double) :: alpha + type(c_ptr), value :: dA + integer(c_int), value :: ldda + type(c_ptr), value :: dx + integer(c_int), value :: incx + real(c_double) :: beta + type(c_ptr), value :: dy + integer(c_int), value :: incy + end function cublasDgemv_v2 + integer(c_int) function & & cublasDgetrfBatched(handle, n, dA, ldda, dP, dInfo, nbatch) & & bind(c, name="cublasDgetrfBatched") @@ -275,6 +317,31 @@ end function cublasDtrsmBatched integer(c_int), value :: nbatch end function cublasDgemmBatched + integer(c_int) function & + & cublasDgemmStridedBatched(handle, transa, transb, m, n, k, alpha, & + & dA, ldda, strideA, dB, lddb, strideB, beta, dC, lddc, strideC, nbatch) & + & bind(c, name="cublasDgemmStridedBatched") + use, intrinsic :: iso_c_binding + type(c_ptr), value :: handle + integer(c_int), value :: transa + integer(c_int), value :: transb + integer(c_int), value :: m + integer(c_int), value :: n + integer(c_int), value :: k + real(c_double) :: alpha + type(c_ptr), value :: dA + integer(c_int), value :: ldda + integer(c_int), value :: strideA + type(c_ptr), value :: dB + integer(c_int), value :: lddb + integer(c_int), value :: strideB + real(c_double) :: beta + type(c_ptr), value :: dC + integer(c_int), value :: strideC + integer(c_int), value :: lddc + integer(c_int), value :: nbatch + end function cublasDgemmStridedBatched + integer(c_int) function & & cublasDtrsv(uplo, trans, diag, n, dA, ldda, dx, incx) & & bind(c, name="cublasDtrsv") @@ -305,11 +372,11 @@ end function cublasDtrsv end function cublasDtrsv_v2 integer(c_int) function & - & cublasDtrsm(uplo, side, trans, diag, m, n, alpha, dA, ldda, dB, lddb) & + & cublasDtrsm(side, uplo, trans, diag, m, n, alpha, dA, ldda, dB, lddb) & & bind(c, name="cublasDtrsm") use, intrinsic :: iso_c_binding - character(c_char), value :: uplo character(c_char), value :: side + character(c_char), value :: uplo character(c_char), value :: trans character(c_char), value :: diag integer(c_int), value :: m @@ -322,12 +389,12 @@ end function cublasDtrsv_v2 end function cublasDtrsm integer(c_int) function & - & cublasDtrsm_v2(handle, uplo, side, trans, diag, m, n, alpha, dA, ldda, dB, lddb) & + & cublasDtrsm_v2(handle, side, uplo, trans, diag, m, n, alpha, dA, ldda, dB, lddb) & & bind(c, name="cublasDtrsm_v2") use, intrinsic :: iso_c_binding type(c_ptr), value :: handle - integer(c_int), value :: uplo integer(c_int), value :: side + integer(c_int), value :: uplo integer(c_int), value :: trans integer(c_int), value :: diag integer(c_int), value :: m @@ -368,12 +435,12 @@ end function cublasDgemm integer(c_int), value :: m integer(c_int), value :: n integer(c_int), value :: k - real(c_double), value :: alpha + real(c_double) :: alpha type(c_ptr), value :: dA integer(c_int), value :: ldda type(c_ptr), value :: dB integer(c_int), value :: lddb - real(c_double), value :: beta + real(c_double) :: beta type(c_ptr), value :: dC integer(c_int), value :: lddc end function cublasDgemm_v2 @@ -395,6 +462,41 @@ end function cublasDgemm_v2 integer(c_int), value :: nbatch end function cublasDgetrsBatched + integer(c_int) function & + & cublasDgeam(handle, transa, transb, m, n, alpha, dA, ldda, beta, dB, lddb, dC, lddc) & + & bind(c, name="cublasDgeam") + use, intrinsic :: iso_c_binding + type(c_ptr), value :: handle + integer(c_int), value :: transa + integer(c_int), value :: transb + integer(c_int), value :: m + integer(c_int), value :: n + real(c_double) :: alpha + type(c_ptr), value :: dA + integer(c_int), value :: ldda + real(c_double) :: beta + type(c_ptr), value :: dB + integer(c_int), value :: lddb + type(c_ptr), value :: dC + integer(c_int), value :: lddc + end function cublasDgeam + + integer(c_int) function & + & cublasDdgmm(handle, mode, m, n, dA, ldda, dx, incx, dC, lddc) & + & bind(c, name="cublasDdgmm") + use, intrinsic :: iso_c_binding + type(c_ptr), value :: handle + integer(c_int), value :: mode + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: dA + integer(c_int), value :: ldda + type(c_ptr), value :: dx + integer(c_int), value :: incx + type(c_ptr), value :: dC + integer(c_int), value :: lddc + end function cublasDdgmm + end interface end module cublasf