From 65791a432ddad0a9130e8878188a84bfd25fbff8 Mon Sep 17 00:00:00 2001 From: yangarbiter Date: Wed, 1 Mar 2017 20:56:46 +0800 Subject: [PATCH] update multilabel QUIRE --- .../multilabel/multilabel_quire.py | 41 +++++++++++++++---- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/libact/query_strategies/multilabel/multilabel_quire.py b/libact/query_strategies/multilabel/multilabel_quire.py index 7dba8f6..e626e62 100644 --- a/libact/query_strategies/multilabel/multilabel_quire.py +++ b/libact/query_strategies/multilabel/multilabel_quire.py @@ -104,9 +104,9 @@ def __init__(self, dataset, lamba=1.0, kernel='rbf', gamma=1., coef0=1., # label correlation matrix R = np.corrcoef(np.array(lbled_Y).T) R = np.nan_to_num(R) + self.RK = np.kron(R, self.K) - self.L = lamba * (np.linalg.pinv(np.kron(R, self.K) \ - + lamba * np.eye(n*m))) + self.L = lamba * (np.linalg.pinv(self.RK + lamba * np.eye(n*m))) @inherit_docstring_from(QueryStrategy) def make_query(self): @@ -115,7 +115,7 @@ def make_query(self): _, lbled_Y = zip(*dataset.get_labeled_entries()) X = np.array(X) - K = self.K + RK = self.RK n_instance = len(X) m = np.shape(lbled_Y)[1] lamba = self.lamba @@ -132,18 +132,41 @@ def make_query(self): L = self.L vecY = np.reshape(np.array([y for y in Y if y is not None]).T, (-1, 1)) detLaa = np.linalg.det(L[np.ix_(a_id, a_id)]) - + #invLaa = np.linalg.pinv(L[np.ix_(a_id, a_id)]) + invLaa = (lamba * np.eye(len(a_id)) + RK[np.ix_(a_id, a_id)]) \ + - np.dot(np.dot(RK[np.ix_(a_id, l_id)], + np.linalg.pinv(lamba * np.eye(len(l_id)) \ + + RK[np.ix_(l_id, l_id)])), + RK[np.ix_(l_id, a_id)]) + + b = np.zeros((len(a_id)-1)) score = [] + D = np.zeros((len(a_id)-1, len(a_id)-1)) + D[...] = invLaa[1:, 1:] for i, s in enumerate(a_id): + # L -> s, Laa -> i u_id = a_id[:i] + a_id[i+1:] - invLuu = L[np.ix_(u_id, u_id)] \ - - 1./L[s, s] * np.dot(L[u_id, s], L[u_id, s].T) + #D = np.delete(np.delete(invLaa, i, axis=0), i, axis=1) + if i > 0: + D[(i-1), :i] = invLaa[(i-1), :i] + D[(i-1), i:] = invLaa[(i-1), (i+1):] + D[:i, (i-1)] = invLaa[:i, (i-1)] + D[i:, (i-1)] = invLaa[(i+1):, (i-1)] + #D[:i, :i] = invLaa[:i, :i] + #D[i:, i:] = invLaa[i+1:, i+1:] + #D[:i, i:] = invLaa[:i, i+1:] + #D[i:, :i] = invLaa[i+1:, :i] + + #b = np.delete(invLaa, i, axis=0)[:, i] + b[:i] = invLaa[:i, i] + b[i:] = invLaa[i+1:, i] + invLuu = D - 1./invLaa[i, i] * np.dot(b, b.T) + score.append(L[s, s] - detLaa / L[s, s] \ - + 2 * np.abs(np.dot(L[np.ix_([s], l_id)] \ + + 2 * np.abs(np.dot(L[s, l_id] \ - np.dot(np.dot(L[s, u_id], invLuu), - L[np.ix_(u_id, l_id)]), vecY))[0][0]) + L[np.ix_(u_id, l_id)]), vecY))) - import ipdb; ipdb.set_trace() score = np.sum(np.array(score).reshape(m, -1).T, axis=1) ask_idx = self.random_state_.choice(np.where(score == np.min(score))[0])