From a47c01b429f981ae2563b57a2375ed400632e44f Mon Sep 17 00:00:00 2001 From: Justin Shenk Date: Wed, 29 May 2019 18:54:04 +0200 Subject: [PATCH] Provide default for n param --- src/closely/main.py | 69 +++++++++++++++++++++++++-------------------- tests/test_solve.py | 4 +-- 2 files changed, 40 insertions(+), 33 deletions(-) diff --git a/src/closely/main.py b/src/closely/main.py index 6b25dd1..7f5358b 100644 --- a/src/closely/main.py +++ b/src/closely/main.py @@ -6,37 +6,37 @@ def solution( array: np.ndarray, - n: Optional[int] = None, + n: Optional[int] = 5, metric="euclidean", quantile: Optional[float] = None, ): """Solve the closest pairs problem. Args: array (np.ndarray): N x M - n (int, optional): number of closest pairs + n (int, optional): number of closest pairs; either n or quantile must be defined metric (str): distance metric quantile (float, optional): between 0 and 1 Returns: - pairs (list of tuples): closest pairs of points + pairs (np.ndarray): closest pairs of points distances (np.ndarray): distances between pairs """ if quantile is not None: index = get_index_of_quantile(array, quantile) pairs, distances = closest_k_pairs(array, kth=index, metric=metric) - else: + elif n is not None: pairs, distances = closest_k_pairs(array, kth=n, metric=metric) return pairs, distances -def closest_k_pairs(array: np.ndarray, kth: int = 3, metric: str = "euclidean"): +def closest_k_pairs(array: np.ndarray, kth:int, metric: str = "euclidean"): """Get closest k-pairs in a matrix. Args: array (np.ndarray): n instances x m features matrix kth (int): k lowest pairs metric (str): distance metric Returns: - pairs (list of tuples): coordinates of nearest pairs, ordered + pairs (np.ndarray): coordinates of nearest pairs, ordered (eg, [[0,2],[3,5],...] distances (np.ndarray): 1-d array of distances for each pair, sorted """ @@ -53,11 +53,21 @@ def closest_k_pairs(array: np.ndarray, kth: int = 3, metric: str = "euclidean"): pairs = list(zip(coord1[:kth], coord2[:kth])) distances = dist_mat[coord1[:kth], coord2[:kth]] - return pairs, distances + indices = np.argsort(distances) + pairs = np.array(pairs)[indices] + return pairs, distances[indices] def get_index_of_quantile(dist_mat: np.ndarray, quantile: float): - """Returns index of `quantile` in `dist_mat`.""" + """Returns index of `quantile` in `dist_mat`. + Args: + dist_mat (np.ndarray): square distance matrix + quantile (float): quantile + + Returns: + index (int): index of quantile + + """ flat_dist_mat = dist_mat.flatten() flat_dist_mat.sort() @@ -67,13 +77,13 @@ def get_index_of_quantile(dist_mat: np.ndarray, quantile: float): def seriation(Z, N, cur_index): - """ - input: - - Z is a hierarchical tree (dendrogram) - - N is the number of points given to the clustering process - - cur_index is the position in the tree for the recursive traversal - output: - - order implied by the hierarchical tree Z + """Order a distance matrix with a hierarchical clustering dendrogram + Args: + Z (np.ndarray): hierarchical tree (dendrogram) + N (int): number of points given to the clustering process + cur_index (int): position in the tree for the recursive traversal + Returns: + order (list of ints): order implied by the hierarchical tree Z seriation computes the order implied by a hierarchical tree (dendrogram) """ @@ -86,22 +96,19 @@ def seriation(Z, N, cur_index): def compute_serial_matrix(dist_mat, method="ward"): - """ - input: - - dist_mat is a distance matrix - - method = ["ward","single","average","complete"] - output: - - seriated_dist is the input dist_mat, - but with re-ordered rows and columns - according to the seriation, i.e. the - order implied by the hierarchical tree - - res_order is the order implied by - the hierarhical tree - - res_linkage is the hierarhical tree (dendrogram) - - compute_serial_matrix transforms a distance matrix into - a sorted distance matrix according to the order implied - by the hierarchical tree (dendrogram) + """Transforms a distance matrix into a sorted distance matrix according to + the order implied by the hierarchical tree (dendrogram) + + Args: + dist_mat (np.ndarray): distance matrix + method (str): one of "ward","single","average","complete" + + Returns: + seriated_dist (np.ndarray): the input dist_mat, but with re-ordered rows and columns + according to the seriation, i.e. the order implied by the hierarchical tree + res_order (np.ndarray): order implied by the hierarhical tree + res_linkage (np.ndarray): hierarhical tree (dendrogram) + """ try: from fastcluster import linkage diff --git a/tests/test_solve.py b/tests/test_solve.py index c977678..7c40c97 100644 --- a/tests/test_solve.py +++ b/tests/test_solve.py @@ -11,7 +11,7 @@ def test_solution(n_points, n_dim): X = np.random.random((1000, n_dim)) pairs, distances = closely.solve(X, n=n_points) - assert isinstance(pairs, list) + assert isinstance(pairs, np.ndarray) assert len(pairs) >= n_points assert isinstance(distances, np.ndarray) @@ -19,7 +19,7 @@ def test_solution(n_points, n_dim): def test_quantile(): X = np.random.random((1000, 8)) pairs, distances = closely.solve(X, quantile=0.01) - assert isinstance(pairs, list) + assert isinstance(pairs, np.ndarray) assert isinstance(distances, np.ndarray)