diff --git a/libact/query_strategies/multilabel/multilabel_quire.py b/libact/query_strategies/multilabel/multilabel_quire.py index e626e62..5c2e76f 100644 --- a/libact/query_strategies/multilabel/multilabel_quire.py +++ b/libact/query_strategies/multilabel/multilabel_quire.py @@ -81,13 +81,10 @@ def __init__(self, dataset, lamba=1.0, kernel='rbf', gamma=1., coef0=1., X, _ = zip(*dataset.get_entries()) self.kernel = kernel if self.kernel == 'rbf': - self.K = rbf_kernel(X=X, Y=X, gamma=kwargs.pop('gamma', 1.)) + self.K = rbf_kernel(X=X, Y=X, gamma=gamma) elif self.kernel == 'poly': - self.K = polynomial_kernel(X=X, - Y=X, - coef0=kwargs.pop('coef0', 1), - degree=kwargs.pop('degree', 3), - gamma=kwargs.pop('gamma', 1.)) + self.K = polynomial_kernel(X=X, Y=X, coef0=coef0, degree=degree, + gamma=gamma) elif self.kernel == 'linear': self.K = linear_kernel(X=X, Y=X) elif hasattr(self.kernel, '__call__'): @@ -99,8 +96,9 @@ def __init__(self, dataset, lamba=1.0, kernel='rbf', gamma=1., coef0=1., _, lbled_Y = zip(*dataset.get_labeled_entries()) + self.n_labels = np.shape(lbled_Y)[1] n = len(X) - m = np.shape(lbled_Y)[1] + m = self.n_labels # label correlation matrix R = np.corrcoef(np.array(lbled_Y).T) R = np.nan_to_num(R) @@ -108,39 +106,60 @@ def __init__(self, dataset, lamba=1.0, kernel='rbf', gamma=1., coef0=1., self.L = lamba * (np.linalg.pinv(self.RK + lamba * np.eye(n*m))) - @inherit_docstring_from(QueryStrategy) - def make_query(self): - dataset = self.dataset - X, Y = zip(*dataset.get_entries()) - _, lbled_Y = zip(*dataset.get_labeled_entries()) - - X = np.array(X) - RK = self.RK - n_instance = len(X) - m = np.shape(lbled_Y)[1] - lamba = self.lamba - + def _get_index(self): + _, Y = zip(*self.dataset.get_entries()) + n_instance = len(Y) + m = self.n_labels # index for labeled and unlabeled instance l_id = [] a_id = [] for i in range(n_instance * m): - if Y[i%n_instance] is None: + if Y[i // m] is None: a_id.append(i) else: l_id.append(i) + return a_id, l_id + + #def update(self, entry_id, label): + # # calculate invLaa + # invLaa = self.invLaa + # # idx before update + # a_id, l_id = self.idxs + # m = len(label) + # # assert len(np.where(np.array(a_id) == entry_id*m)[0]) == 1 + # idx = np.where(np.array(a_id) == entry_id*m)[0][0] + # for i in range(m): + # D = np.delete(np.delete(invLaa, idx, axis=0), idx, axis=1) + # b = np.delete(invLaa, idx, axis=0)[:, idx] + # # invLuu + # invLaa = D - 1./invLaa[idx, idx] * np.dot(b, b.T) + # self.invLaa = invLaa + + @inherit_docstring_from(QueryStrategy) + def make_query(self): + dataset = self.dataset + X, Y = zip(*dataset.get_entries()) + X = np.array(X) + n_instance = len(X) + m = self.n_labels + RK = self.RK + lamba = self.lamba L = self.L - vecY = np.reshape(np.array([y for y in Y if y is not None]).T, (-1, 1)) - detLaa = np.linalg.det(L[np.ix_(a_id, a_id)]) - #invLaa = np.linalg.pinv(L[np.ix_(a_id, a_id)]) - invLaa = (lamba * np.eye(len(a_id)) + RK[np.ix_(a_id, a_id)]) \ + + a_id, l_id = self._get_index() + # invLaa = np.linalg.pinv(L[np.ix_(a_id, a_id)]) + invLaa = ((lamba * np.eye(len(a_id)) + RK[np.ix_(a_id, a_id)]) \ - np.dot(np.dot(RK[np.ix_(a_id, l_id)], np.linalg.pinv(lamba * np.eye(len(l_id)) \ + RK[np.ix_(l_id, l_id)])), - RK[np.ix_(l_id, a_id)]) + RK[np.ix_(l_id, a_id)])) / lamba + + vecY = np.reshape(np.array([y for y in Y if y is not None]).T, (-1, 1)) + detLaa = np.linalg.det(L[np.ix_(a_id, a_id)]) + score = np.zeros(len(a_id)) b = np.zeros((len(a_id)-1)) - score = [] D = np.zeros((len(a_id)-1, len(a_id)-1)) D[...] = invLaa[1:, 1:] for i, s in enumerate(a_id): @@ -162,13 +181,13 @@ def make_query(self): b[i:] = invLaa[i+1:, i] invLuu = D - 1./invLaa[i, i] * np.dot(b, b.T) - score.append(L[s, s] - detLaa / L[s, s] \ - + 2 * np.abs(np.dot(L[s, l_id] \ - - np.dot(np.dot(L[s, u_id], invLuu), - L[np.ix_(u_id, l_id)]), vecY))) + score[i] = L[s, s] - detLaa / L[s, s] \ + + 2 * np.abs(np.dot(L[s, l_id] \ + - np.dot(np.dot(L[s, u_id], invLuu), + L[np.ix_(u_id, l_id)]), vecY)) - score = np.sum(np.array(score).reshape(m, -1).T, axis=1) + score = np.sum(score.reshape(m, -1).T, axis=1) ask_idx = self.random_state_.choice(np.where(score == np.min(score))[0]) - return a_id[ask_idx] + return a_id[ask_idx] // m diff --git a/libact/query_strategies/multilabel/tests/test_multilabel_quire.py b/libact/query_strategies/multilabel/tests/test_multilabel_quire.py new file mode 100644 index 0000000..17040d1 --- /dev/null +++ b/libact/query_strategies/multilabel/tests/test_multilabel_quire.py @@ -0,0 +1,26 @@ +import unittest + +from numpy.testing import assert_array_equal +import numpy as np + +from libact.base.dataset import Dataset +from libact.query_strategies.multilabel import MultilabelQUIRE +from libact.utils import run_qs + + +class MultilabelQUIRETestCase(unittest.TestCase): + """Variance reduction test case using artifitial dataset""" + def setUp(self): + self.X = [[-2, -1], [1, 1], [-1, -2], [-1, -1], [1, 2], [2, 1]] + self.y = [[0, 1], [1, 0], [0, 1], [1, 0], [1, 0], [1, 1]] + self.quota = 4 + + def test_multilabel_quire(self): + trn_ds = Dataset(self.X, (self.y[:2] + [None] * (len(self.y) - 2))) + qs = MultilabelQUIRE(trn_ds) + qseq = run_qs(trn_ds, qs, self.y, self.quota) + assert_array_equal(qseq, np.array([2, 3, 4, 5])) + + +if __name__ == '__main__': + unittest.main() diff --git a/libact/utils/__init__.py b/libact/utils/__init__.py index d58fc0d..93af3f0 100644 --- a/libact/utils/__init__.py +++ b/libact/utils/__init__.py @@ -50,3 +50,33 @@ def calc_cost(y, yhat, cost_matrix): ith class and prediction as jth class. """ return np.mean(cost_matrix[list(y), list(yhat)]) + +def run_qs(trn_ds, qs, truth, quota): + """Run query strategy on specified dataset and return quering sequence. + + Parameters + ---------- + trn_ds : Dataset object + The dataset to be run on. + + qs : QueryStrategy instance + The active learning algorith to be run. + + truth : array-like + The true label. + + quota : int + Number of iterations to run + + Returns + ------- + qseq : numpy array, shape (quota,) + The numpy array of entry_id representing querying sequence. + """ + ret = [] + for _ in range(quota): + ask_id = qs.make_query() + trn_ds.update(ask_id, truth[ask_id]) + + ret.append(ask_id) + return np.array(ret)