Skip to content

Commit

Permalink
multilabel QUIRE speed up a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
yangarbiter committed Feb 24, 2017
1 parent 145e131 commit a92377e
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions libact/query_strategies/multilabel/multilabel_quire.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(self, dataset, lamba=1.0, kernel='rbf', gamma=1., coef0=1.,

self.random_state_ = seed_random_state(random_state)

@profile
@inherit_docstring_from(QueryStrategy)
def make_query(self):
dataset = self.dataset
Expand All @@ -120,24 +121,21 @@ def make_query(self):
R = np.nan_to_num(R)

L = lamba * (np.linalg.pinv(np.kron(R, K) + lamba * np.eye(n*m)))
inv_L = np.linalg.pinv(L)

vecY = np.reshape(np.array([y for y in Y if y is not None]), (-1, 1))
invLuu = np.linalg.pinv(L[np.ix_(u, u)])

score = np.zeros((n, m))
vYLllvY = np.dot(np.dot(vecY.T, L[np.ix_(l, l)]), vecY)[0, 0]
U = np.dot(L[np.ix_(u, l)], vecY)
temp0 = -(np.dot(np.dot(U.T, invLuu), U))[0, 0]
for a in range(n):
for b in range(m):
s = b*n + a
U = np.dot(L[np.ix_(u, l)], vecY) + L[np.ix_(u, [s])]
temp1 = 2 * np.dot(L[[s], l], vecY) \
- np.dot(np.dot(U.T, invLuu), U)
U = np.dot(L[np.ix_(u, l)], vecY)
temp0 = -(np.dot(np.dot(U.T, invLuu), U))
score[a, b] = L[s, s] \
+ np.dot(np.dot(vecY.T, L[np.ix_(l, l)]),
vecY)[0, 0]\
+ np.max((temp1[0, 0], temp0[0, 0]))
tU = U + L[np.ix_(u, [s])]
temp1 = (2 * np.dot(L[[s], l], vecY) \
- np.dot(np.dot(tU.T, invLuu), tU))[0, 0]
score[a, b] = L[s, s] + vYLllvY + np.max((temp1, temp0))

score = np.sum(score, axis=1)

Expand Down

0 comments on commit a92377e

Please sign in to comment.