-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcs_algos.py
66 lines (44 loc) · 1.25 KB
/
cs_algos.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
'''
ALGORITHMS FOR COMPRESSIVE SENSING AND SPARSE SIGNAL PROCESSING
AUTHOR: ABIJITH J. KAMATH
abijithj@iisc.ac.in
'''
# %% IMPORT LIBRARIES
import numpy as np
# %% ACTIVATION FUNCTIONS
def shrinkage(x, lambd):
    '''
    Soft-thresholding of x with threshold lambd (element-wise).

    Real x:    sign(x) * max(|x| - lambd, 0)
    Complex x: x / |x| * max(|x| - lambd, 0), and 0 where x == 0,
               i.e. the phase is kept and only the magnitude is shrunk.

    :param x: Scalar or array to threshold (real or complex)
    :param lambd: Threshold (expected to be non-negative)
    :returns: Thresholded value with the same shape as x
    '''
    if np.iscomplexobj(x):
        # np.maximum compares complex numbers lexicographically, so the
        # real-valued formula below is meaningless here; shrink |x| instead.
        magnitude = np.abs(x)
        shrunk = np.maximum(magnitude - lambd, 0)
        # Avoid 0/0 where x == 0; shrunk is 0 there, so the result is 0.
        safe_magnitude = np.where(magnitude > 0, magnitude, 1)
        return x * (shrunk / safe_magnitude)
    return np.maximum(0, x - lambd) - np.maximum(0, -x - lambd)
# %% SPARSE RECOVERY ALGORITHMS
def l0_ihta():
    '''
    Placeholder for an l0 sparse-recovery routine (presumably iterative
    hard thresholding, judging by the name -- TODO confirm and implement).

    Not implemented yet: takes no arguments and returns None.
    '''
    return None
def l1_bp():
    '''
    Placeholder for an l1 sparse-recovery routine (presumably basis
    pursuit, judging by the name -- TODO confirm and implement).

    Not implemented yet: takes no arguments and returns None.
    '''
    return None
def l1_ista(y, A, lambd, max_iter=100, tol=1e-12):
    '''
    Sparse recovery using iterative soft-thresholding (ISTA)

    Starting from x = 0, repeats the proximal-gradient step
        x <- shrinkage(x + t * A^H (y - A x), lambd * t)
    with step size t = 1 / ||A||_2^2, which targets
        0.5 * ||y - A x||_2^2 + lambd * ||x||_1

    :param y: Measurement vector, shape (m,) or (m, 1); a 2-D (m, k) array
              is treated as k measurement vectors (stopping is decided
              jointly over all columns)
    :param A: Sensing matrix, shape (m, n)
    :param lambd: Penalty parameter (weight of the l1 term)
    :param max_iter: Maximum number of iterations
    :param tol: Relative tolerance; iteration stops once
                ||x - x_old|| <= tol * ||x_old||
    :returns: Tuple (x, errors). x is the recovered sparse vector with shape
              (n,) + y.shape[1:]; it is real unless A or y is complex.
              errors lists the squared residual ||y - A x||_2^2 after
              every iteration.
    '''
    A = np.asarray(A)
    y = np.asarray(y)
    _, n = A.shape
    # Shape the iterate after y so 1-D y stays 1-D instead of broadcasting
    # an (n, 1) iterate into an (m, m) residual.
    dtype = np.result_type(A, y, np.float64)
    x = np.zeros((n,) + y.shape[1:], dtype=dtype)
    errors = []
    # Lipschitz constant of the gradient of 0.5*||y - Ax||^2: the largest
    # eigenvalue of A^H A, i.e. the squared spectral norm of A. The norm is
    # always real and non-negative, unlike eig() on a complex A^H A.
    lipschitz = np.linalg.norm(A, 2) ** 2
    if lipschitz == 0:
        # A is the zero matrix: the residual does not depend on x, so the
        # l1 penalty alone is minimised by x = 0.
        return x, errors
    t = 1.0 / lipschitz
    AH = np.conj(A.T)
    for _ in range(max_iter):
        xold = x
        x = shrinkage(x + t * (AH @ (y - A @ x)), lambd * t)
        errors.append(np.linalg.norm(y - A @ x) ** 2)
        # Multiplicative form: no division by ||xold|| == 0 (first
        # iteration), and it still stops when x is stuck at zero.
        if np.linalg.norm(x - xold) <= tol * np.linalg.norm(xold):
            break
    return x, errors
def l1_fista():
    '''
    Placeholder for an l1 sparse-recovery routine (presumably fast
    iterative soft-thresholding, judging by the name -- TODO confirm and
    implement).

    Not implemented yet: takes no arguments and returns None.
    '''
    return None
def l1_admm():
    '''
    Placeholder for an l1 sparse-recovery routine (presumably ADMM,
    judging by the name -- TODO confirm and implement).

    Not implemented yet: takes no arguments and returns None.
    '''
    return None