-
Notifications
You must be signed in to change notification settings - Fork 48
/
Copy pathvbgmm.pyx
45 lines (30 loc) · 1.24 KB
/
vbgmm.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""
vbgmm.pyx
Simple Cython wrapper for a variational Gaussian mixture model implemented in C.
"""
import cython
# import both numpy and the Cython declarations for numpy
import numpy as np
cimport numpy as np
# declare the interface to the C code
cdef extern void c_vbgmm_fit (double* adX, int nN, int nD, int nK, int seed, int* anAssign, int nThreads, int nIter)
@cython.boundscheck(False)
@cython.wraparound(False)
def fit(np.ndarray[double, ndim=2, mode="c"] xarray not None, nClusters, seed, threads, piter):
    """
    fit(xarray, nClusters, seed, threads, piter)

    Fit the variational Gaussian mixture model (C routine ``c_vbgmm_fit``)
    to the rows of ``xarray``, starting from ``nClusters`` initial clusters,
    and return the assignment array the C routine fills in.

    param: xarray -- a 2-d, C-contiguous numpy array of np.float64 with
                     shape (nN, nD); one assignment is produced per row
    param: nClusters -- an int, number of start clusters
    param: seed -- an int, the random seed
    param: threads -- int, the number of threads to use
    param: piter -- int, the number of VB iterations to use
    returns: a 1-d numpy array of np.intc with length nN, written by
             c_vbgmm_fit (presumably the cluster index of each row --
             confirm against the C implementation). An empty array is
             returned when xarray has no rows.
    raises: ValueError if xarray has rows but zero columns.
    """
    cdef int nN, nD, nK, nThreads, nIter
    cdef np.ndarray[int, ndim=1, mode="c"] assign

    # NOTE(review): shape values are narrowed to C int here because the
    # C interface takes int sizes; very large inputs would overflow.
    nN, nD = xarray.shape[0], xarray.shape[1]

    # Bounds checking is disabled for this function, so &xarray[0, 0] and
    # &assign[0] on an empty buffer would silently produce out-of-range
    # pointers for the C code. Handle the degenerate shapes up front.
    if nN == 0:
        return np.zeros(0, dtype=np.intc)
    if nD == 0:
        raise ValueError("xarray must have at least one column")

    nK = nClusters
    nIter = piter
    nThreads = threads

    # np.intc matches the C `int` element type expected for anAssign.
    assign = np.zeros(nN, dtype=np.intc)
    c_vbgmm_fit(&xarray[0, 0], nN, nD, nK, seed, &assign[0], nThreads, nIter)
    return assign