-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathksvd.py
executable file
·96 lines (81 loc) · 2.89 KB
/
ksvd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import numpy as np
import scipy as sp
from sklearn.linear_model import orthogonal_mp_gram
from tqdm import tqdm
class ApproximateKSVD(object):
    """Approximate K-SVD dictionary learning.

    ``fit`` alternates between sparse coding of the samples with
    ``orthogonal_mp_gram`` and a per-atom dictionary update, so that
    ``X`` is approximated by ``gamma_.dot(components_)``.

    Attributes
    ----------
    components_:
        Learned dictionary, shape = [n_components, n_features]; ``None``
        until ``fit`` has been called.
    gamma_:
        Sparse codes of the training data from the last iteration,
        shape = [n_samples, n_components]; ``None`` until ``fit`` has been
        called.
    """

    def __init__(self, n_components, max_iter=10, tol=1e-6, omp_tol=None,
                 transform_n_nonzero_coefs=None):
        """
        This is a KSVD implementation based on: https://github.com/nel215/ksvd

        Parameters
        ----------
        n_components:
            Number of dictionary elements
        max_iter:
            Maximum number of iterations
        tol:
            tolerance for error; fitting stops early once the Frobenius
            norm of the reconstruction residual drops below it
        omp_tol:
            Forwarded as ``tol`` to ``orthogonal_mp_gram``
        transform_n_nonzero_coefs:
            Number of nonzero coefficients to target
        """
        self.components_ = None
        self.gamma_ = None
        self.max_iter = max_iter
        self.tol = tol
        self.omp_tol = omp_tol
        self.n_components = n_components
        self.transform_n_nonzero_coefs = transform_n_nonzero_coefs

    def _update_dict(self, X, D, gamma):
        """Update every dictionary atom and its coefficients in turn.

        Parameters
        ----------
        X: shape = [n_samples, n_features]
        D: shape = [n_components, n_features], modified in place
        gamma: shape = [n_samples, n_components], modified in place
            (a 1-D array is first reshaped to two dimensions)

        Returns
        -------
        (D, gamma) after the update.
        """
        if gamma.ndim == 1:
            # The sparse coder squeezes its output when there is a single
            # sample or a single atom; restore the (n_samples, n_components)
            # layout in both cases.
            gamma = gamma.reshape(X.shape[0], -1)
        for j in range(self.n_components):
            # Samples that use atom j. Coefficients may be negative, so the
            # test must be "non-zero", not "positive".
            active = gamma[:, j] != 0
            if not np.any(active):
                continue
            previous_atom = D[j, :].copy()
            # Zero the atom so the residual below excludes its contribution.
            D[j, :] = 0
            coefs = gamma[active, j]
            residual = X[active, :] - gamma[active, :].dot(D)
            atom = residual.T.dot(coefs)
            norm = np.linalg.norm(atom)
            if norm == 0:
                # Nothing to learn from a zero projection; keep the old atom
                # instead of dividing by zero and filling the row with NaN.
                D[j, :] = previous_atom
                continue
            atom /= norm
            gamma[active, j] = residual.dot(atom)
            D[j, :] = atom
        return D, gamma

    def _initialize(self, X):
        """Build an initial dictionary with unit-norm rows.

        Uses the leading ``n_components`` right singular vectors of ``X``
        scaled by their singular values when the truncated SVD can provide
        that many, and random Gaussian rows otherwise.
        """
        # Imported here because ``import scipy as sp`` alone does not
        # guarantee that the ``sparse.linalg`` submodule has been loaded.
        from scipy.sparse.linalg import svds

        # The default (ARPACK) svds solver requires k < min(X.shape), so the
        # equality case must also fall back to random initialisation.
        if min(X.shape) <= self.n_components:
            D = np.random.randn(self.n_components, X.shape[1])
        else:
            _, s, vt = svds(X, k=self.n_components)
            D = np.dot(np.diag(s), vt)
        norms = np.linalg.norm(D, axis=1)
        degenerate = norms == 0
        if np.any(degenerate):
            # A zero row cannot be normalised; replace it with a random one.
            D[degenerate] = np.random.randn(int(degenerate.sum()), X.shape[1])
            norms[degenerate] = np.linalg.norm(D[degenerate], axis=1)
        D /= norms[:, np.newaxis]
        return D

    def _transform(self, D, X):
        """Sparse-code the rows of ``X`` against dictionary ``D``.

        Returns the transposed output of ``orthogonal_mp_gram``.
        """
        gram = D.dot(D.T)
        Xy = D.dot(X.T)
        norms_squared = np.sum((X.T) ** 2, axis=0)

        n_nonzero_coefs = self.transform_n_nonzero_coefs
        if n_nonzero_coefs is None:
            # Default to 10% of the features, but never fewer than one:
            # a value of 0 is rejected by orthogonal_mp_gram.
            n_nonzero_coefs = max(int(0.1 * X.shape[1]), 1)

        return orthogonal_mp_gram(
            gram, Xy, n_nonzero_coefs=n_nonzero_coefs, tol=self.omp_tol,
            norms_squared=norms_squared).T

    def fit(self, X, is_D_init=False, D_init=None):
        """
        Learn the dictionary from ``X``.

        Parameters
        ----------
        X: shape = [n_samples, n_features]
        is_D_init:
            When True, start from ``D_init`` instead of the built-in
            initialisation
        D_init: shape = [n_components, n_features]
            Initial dictionary; it is copied, so the caller's array is
            left untouched

        Returns
        -------
        self, with ``components_`` and ``gamma_`` set.

        Raises
        ------
        ValueError
            If ``is_D_init`` is True but ``D_init`` is None.
        """
        if not is_D_init:
            D = self._initialize(X)
        else:
            if D_init is None:
                raise ValueError(
                    "D_init must be provided when is_D_init is True")
            # Work on a float copy: the dictionary update writes in place.
            D = np.array(D_init, dtype=float)

        code_shape = (X.shape[0], D.shape[0])
        gamma = None
        for _ in tqdm(range(self.max_iter)):
            # reshape undoes the squeeze applied by the sparse coder for a
            # single sample or a single atom.
            gamma = np.reshape(self._transform(D, X), code_shape)
            e = np.linalg.norm(X - gamma.dot(D))
            if e < self.tol:
                break
            D, gamma = self._update_dict(X, D, gamma)

        if gamma is None:
            # max_iter <= 0: the loop never ran, still provide the codes.
            gamma = np.reshape(self._transform(D, X), code_shape)

        self.components_ = D
        self.gamma_ = gamma
        return self

    def transform(self, X):
        """Sparse-code ``X`` against the fitted dictionary ``components_``."""
        return self._transform(self.components_, X)