Skip to content

Commit

Permalink
Update cublas interface
Browse files Browse the repository at this point in the history
  • Loading branch information
jaharris87 committed Jun 28, 2024
1 parent f3ffdbd commit e2e1bb2
Showing 1 changed file with 108 additions and 6 deletions.
114 changes: 108 additions & 6 deletions source/cublasf.f90
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,48 @@ end function cublasGetMatrix
type(c_ptr), value :: stream
end function cublasGetMatrixAsync

integer(c_int) function &
& cublasDnrm2_v2(handle, n, dx, incx, xnorm) &
& bind(c, name="cublasDnrm2_v2")
use, intrinsic :: iso_c_binding
type(c_ptr), value :: handle
integer(c_int), value :: n
type(c_ptr), value :: dx
integer(c_int), value :: incx
real(c_double) :: xnorm
end function cublasDnrm2_v2

integer(c_int) function &
& cublasDaxpy_v2(handle, n, alpha, dx, incx, dy, incy) &
& bind(c, name="cublasDaxpy_v2")
use, intrinsic :: iso_c_binding
type(c_ptr), value :: handle
integer(c_int), value :: n
real(c_double), value :: alpha
type(c_ptr), value :: dx
integer(c_int), value :: incx
type(c_ptr), value :: dy
integer(c_int), value :: incy
end function cublasDaxpy_v2

integer(c_int) function &
& cublasDgemv_v2(handle, trans, m, n, alpha, dA, ldda, dx, incx, beta, dy, incy) &
& bind(c, name="cublasDgemv_v2")
use, intrinsic :: iso_c_binding
type(c_ptr), value :: handle
integer(c_int), value :: trans
integer(c_int), value :: m
integer(c_int), value :: n
real(c_double) :: alpha
type(c_ptr), value :: dA
integer(c_int), value :: ldda
type(c_ptr), value :: dx
integer(c_int), value :: incx
real(c_double) :: beta
type(c_ptr), value :: dy
integer(c_int), value :: incy
end function cublasDgemv_v2

integer(c_int) function &
& cublasDgetrfBatched(handle, n, dA, ldda, dP, dInfo, nbatch) &
& bind(c, name="cublasDgetrfBatched")
Expand Down Expand Up @@ -275,6 +317,31 @@ end function cublasDtrsmBatched
integer(c_int), value :: nbatch
end function cublasDgemmBatched

integer(c_int) function &
& cublasDgemmStridedBatched(handle, transa, transb, m, n, k, alpha, &
& dA, ldda, strideA, dB, lddb, strideB, beta, dC, lddc, strideC, nbatch) &
& bind(c, name="cublasDgemmStridedBatched")
use, intrinsic :: iso_c_binding
type(c_ptr), value :: handle
integer(c_int), value :: transa
integer(c_int), value :: transb
integer(c_int), value :: m
integer(c_int), value :: n
integer(c_int), value :: k
real(c_double) :: alpha
type(c_ptr), value :: dA
integer(c_int), value :: ldda
integer(c_int), value :: strideA
type(c_ptr), value :: dB
integer(c_int), value :: lddb
integer(c_int), value :: strideB
real(c_double) :: beta
type(c_ptr), value :: dC
integer(c_int), value :: strideC
integer(c_int), value :: lddc
integer(c_int), value :: nbatch
end function cublasDgemmStridedBatched

integer(c_int) function &
& cublasDtrsv(uplo, trans, diag, n, dA, ldda, dx, incx) &
& bind(c, name="cublasDtrsv")
Expand Down Expand Up @@ -305,11 +372,11 @@ end function cublasDtrsv
end function cublasDtrsv_v2

integer(c_int) function &
& cublasDtrsm(uplo, side, trans, diag, m, n, alpha, dA, ldda, dB, lddb) &
& cublasDtrsm(side, uplo, trans, diag, m, n, alpha, dA, ldda, dB, lddb) &
& bind(c, name="cublasDtrsm")
use, intrinsic :: iso_c_binding
character(c_char), value :: uplo
character(c_char), value :: side
character(c_char), value :: uplo
character(c_char), value :: trans
character(c_char), value :: diag
integer(c_int), value :: m
Expand All @@ -322,12 +389,12 @@ end function cublasDtrsv_v2
end function cublasDtrsm

integer(c_int) function &
& cublasDtrsm_v2(handle, uplo, side, trans, diag, m, n, alpha, dA, ldda, dB, lddb) &
& cublasDtrsm_v2(handle, side, uplo, trans, diag, m, n, alpha, dA, ldda, dB, lddb) &
& bind(c, name="cublasDtrsm_v2")
use, intrinsic :: iso_c_binding
type(c_ptr), value :: handle
integer(c_int), value :: uplo
integer(c_int), value :: side
integer(c_int), value :: uplo
integer(c_int), value :: trans
integer(c_int), value :: diag
integer(c_int), value :: m
Expand Down Expand Up @@ -368,12 +435,12 @@ end function cublasDgemm
integer(c_int), value :: m
integer(c_int), value :: n
integer(c_int), value :: k
real(c_double), value :: alpha
real(c_double) :: alpha
type(c_ptr), value :: dA
integer(c_int), value :: ldda
type(c_ptr), value :: dB
integer(c_int), value :: lddb
real(c_double), value :: beta
real(c_double) :: beta
type(c_ptr), value :: dC
integer(c_int), value :: lddc
end function cublasDgemm_v2
Expand All @@ -395,6 +462,41 @@ end function cublasDgemm_v2
integer(c_int), value :: nbatch
end function cublasDgetrsBatched

integer(c_int) function &
& cublasDgeam(handle, transa, transb, m, n, alpha, dA, ldda, beta, dB, lddb, dC, lddc) &
& bind(c, name="cublasDgeam")
use, intrinsic :: iso_c_binding
type(c_ptr), value :: handle
integer(c_int), value :: transa
integer(c_int), value :: transb
integer(c_int), value :: m
integer(c_int), value :: n
real(c_double) :: alpha
type(c_ptr), value :: dA
integer(c_int), value :: ldda
real(c_double) :: beta
type(c_ptr), value :: dB
integer(c_int), value :: lddb
type(c_ptr), value :: dC
integer(c_int), value :: lddc
end function cublasDgeam

integer(c_int) function &
& cublasDdgmm(handle, mode, m, n, dA, ldda, dx, incx, dC, lddc) &
& bind(c, name="cublasDdgmm")
use, intrinsic :: iso_c_binding
type(c_ptr), value :: handle
integer(c_int), value :: mode
integer(c_int), value :: m
integer(c_int), value :: n
type(c_ptr), value :: dA
integer(c_int), value :: ldda
type(c_ptr), value :: dx
integer(c_int), value :: incx
type(c_ptr), value :: dC
integer(c_int), value :: lddc
end function cublasDdgmm

end interface

end module cublasf

0 comments on commit e2e1bb2

Please sign in to comment.