-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathwrapper.pyx
46 lines (34 loc) · 1.85 KB
/
wrapper.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np
cimport numpy as np
#assert sizeof(int) == sizeof(np.int32_t)
#assert sizeof(np.float32) == sizeof(np.float32_t)
cdef extern from "Solver_manager.hh":
    # Declarations mirroring the C++ ``DnSolver`` class from Solver_manager.hh.
    # It is exposed to Cython under the alias ``C_DnSolver`` so that the
    # Python-facing extension type in this module can reuse the name
    # ``DnSolver``.
    #
    # ``except +`` tells Cython to catch any C++ exception thrown by these
    # calls and re-raise it as a Python exception. Examples are std::bad_alloc
    # from ``new C_DnSolver(...)`` or an error raised by the solver itself.
    # Without it, a C++ exception unwinds through generated C code and
    # terminates the interpreter. It does not change the C++ signatures.
    cdef cppclass C_DnSolver "DnSolver":
        # (rows, cols) of the system to solve.
        C_DnSolver(np.int32_t, np.int32_t) except +
        # (matrix values, right-hand side)
        void from_dense(np.float32_t*, np.float32_t*) except +
        # (indptr, indices, data, right-hand side)
        void from_csr(np.int32_t*, np.int32_t*, np.float32_t*, np.float32_t*) except +
        # (row/first index array, column/second index array, data, nnz, right-hand side)
        void from_coo(np.int32_t*, np.int32_t*, np.float32_t*, np.int32_t, np.float32_t*) except +
        # NOTE(review): meaning of the two int selectors is defined in the
        # C++ header, not visible here.
        void solve(np.int32_t, np.int32_t) except +
        void solve_Axb(np.int32_t) except +
        # Copies the solution into a caller-provided float32 buffer.
        void retrieve_to(np.float32_t*) except +
cdef class DnSolver:
    """Python wrapper owning one native ``C_DnSolver`` instance.

    The native object is allocated in ``__cinit__`` and released in
    ``__dealloc__``. Every array argument is handed to the C++ side as a raw
    pointer to its first element. Each buffer is therefore declared
    ``mode="c"``: a non-contiguous view is rejected with ``ValueError`` instead
    of silently feeding the solver the wrong elements.
    """
    cdef C_DnSolver* g
    cdef int rows
    cdef int cols

    def __cinit__(self, np.int32_t rows, np.int32_t cols):
        """Create the native solver for a system with ``rows`` x ``cols`` shape.

        Raises:
            ValueError: if ``rows`` or ``cols`` is negative.
        """
        if rows < 0 or cols < 0:
            raise ValueError(
                "DnSolver dimensions must be non-negative (rows=%d, cols=%d)"
                % (rows, cols))
        self.rows, self.cols = rows, cols
        self.g = new C_DnSolver(self.rows, self.cols)

    def __dealloc__(self):
        # Free the object allocated with ``new`` in __cinit__. Previously it
        # was never deleted, so every DnSolver leaked its native solver.
        # Cython zero-initialises C attributes, so ``g`` is NULL if __cinit__
        # raised before allocating.
        if self.g != NULL:
            del self.g
            self.g = NULL

    def from_dense(self,
                   np.ndarray[ndim=1, dtype=np.float32_t, mode="c"] arr,
                   np.ndarray[ndim=1, dtype=np.float32_t, mode="c"] rhs):
        """Load a dense matrix ``arr`` and right-hand side ``rhs``.

        Both must be 1-D, C-contiguous float32 arrays.
        NOTE(review): the expected length and layout of ``arr`` (presumably a
        flattened rows*cols matrix; row- vs column-major unknown) and of
        ``rhs`` are defined in Solver_manager.hh. Confirm there.
        """
        self.g.from_dense(&arr[0], &rhs[0])

    def from_csr(self,
                 np.ndarray[ndim=1, dtype=np.int32_t, mode="c"] indptr,
                 np.ndarray[ndim=1, dtype=np.int32_t, mode="c"] indices,
                 np.ndarray[ndim=1, dtype=np.float32_t, mode="c"] data,
                 np.ndarray[ndim=1, dtype=np.float32_t, mode="c"] rhs):
        """Load a matrix in CSR form (``indptr``, ``indices``, ``data``) plus ``rhs``.

        All arrays must be 1-D and C-contiguous: int32 for the index arrays,
        float32 for ``data`` and ``rhs``.
        NOTE(review): presumably ``indptr`` has rows+1 entries, following
        the usual CSR convention. The required lengths are not checked here.
        """
        self.g.from_csr(&indptr[0], &indices[0], &data[0], &rhs[0])

    def from_coo(self,
                 np.ndarray[ndim=1, dtype=np.int32_t, mode="c"] indptr,
                 np.ndarray[ndim=1, dtype=np.int32_t, mode="c"] indices,
                 np.ndarray[ndim=1, dtype=np.float32_t, mode="c"] data,
                 np.int32_t nnz,
                 np.ndarray[ndim=1, dtype=np.float32_t, mode="c"] rhs):
        """Load a matrix in COO form with ``nnz`` entries, plus ``rhs``.

        Despite its name, ``indptr`` is the first COO index array
        (presumably row indices) and ``indices`` the second. The parameter
        names are kept for keyword-argument compatibility.

        Raises:
            ValueError: if ``nnz`` is negative or exceeds the length of any
                of ``indptr``, ``indices`` or ``data``.
        """
        # The C++ side only receives raw pointers plus ``nnz``, so an
        # oversized ``nnz`` would make it read past the end of the buffers.
        if (nnz < 0 or nnz > indptr.shape[0] or nnz > indices.shape[0]
                or nnz > data.shape[0]):
            raise ValueError(
                "nnz=%d is inconsistent with COO array lengths (%d, %d, %d)"
                % (nnz, indptr.shape[0], indices.shape[0], data.shape[0]))
        self.g.from_coo(&indptr[0], &indices[0], &data[0], nnz, &rhs[0])

    def solve(self, np.int32_t multFunc, np.int32_t func):
        """Run the native solver.

        NOTE(review): ``multFunc`` and ``func`` are integer selectors whose
        meaning is defined in Solver_manager.hh, not visible here.
        """
        self.g.solve(multFunc, func)

    def solve_Axb(self, np.int32_t func):
        """Run the native ``solve_Axb`` routine with integer selector ``func``.

        NOTE(review): the meaning of ``func`` is defined in the C++ header.
        """
        self.g.solve_Axb(func)

    def retrieve(self):
        """Return the solver's result as a new 1-D float32 array of length ``cols``."""
        cdef np.ndarray[ndim=1, dtype=np.float32_t, mode="c"] x = np.zeros(self.cols, dtype=np.float32)
        self.g.retrieve_to(&x[0])
        return x