From a92377e8349ce8c31a79eedc8a366d0a2a6bc751 Mon Sep 17 00:00:00 2001 From: yangarbiter Date: Fri, 24 Feb 2017 14:50:45 +0800 Subject: [PATCH] multilabel QUIRE speed up a bit --- .../multilabel/multilabel_quire.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/libact/query_strategies/multilabel/multilabel_quire.py b/libact/query_strategies/multilabel/multilabel_quire.py index b83aa85..d9a13be 100644 --- a/libact/query_strategies/multilabel/multilabel_quire.py +++ b/libact/query_strategies/multilabel/multilabel_quire.py @@ -97,6 +97,7 @@ def __init__(self, dataset, lamba=1.0, kernel='rbf', gamma=1., coef0=1., self.random_state_ = seed_random_state(random_state) + @profile @inherit_docstring_from(QueryStrategy) def make_query(self): dataset = self.dataset @@ -120,24 +121,21 @@ def make_query(self): R = np.nan_to_num(R) L = lamba * (np.linalg.pinv(np.kron(R, K) + lamba * np.eye(n*m))) - inv_L = np.linalg.pinv(L) vecY = np.reshape(np.array([y for y in Y if y is not None]), (-1, 1)) invLuu = np.linalg.pinv(L[np.ix_(u, u)]) score = np.zeros((n, m)) + vYLllvY = np.dot(np.dot(vecY.T, L[np.ix_(l, l)]), vecY)[0, 0] + U = np.dot(L[np.ix_(u, l)], vecY) + temp0 = -(np.dot(np.dot(U.T, invLuu), U))[0, 0] for a in range(n): for b in range(m): s = b*n + a - U = np.dot(L[np.ix_(u, l)], vecY) + L[np.ix_(u, [s])] - temp1 = 2 * np.dot(L[[s], l], vecY) \ - - np.dot(np.dot(U.T, invLuu), U) - U = np.dot(L[np.ix_(u, l)], vecY) - temp0 = -(np.dot(np.dot(U.T, invLuu), U)) - score[a, b] = L[s, s] \ - + np.dot(np.dot(vecY.T, L[np.ix_(l, l)]), - vecY)[0, 0]\ - + np.max((temp1[0, 0], temp0[0, 0])) + tU = U + L[np.ix_(u, [s])] + temp1 = (2 * np.dot(L[[s], l], vecY) \ + - np.dot(np.dot(tU.T, invLuu), tU))[0, 0] + score[a, b] = L[s, s] + vYLllvY + np.max((temp1, temp0)) score = np.sum(score, axis=1)