Skip to content

Commit

Permalink
Provide default for n param
Browse files Browse the repository at this point in the history
  • Loading branch information
JustinShenk committed May 29, 2019
1 parent 4c1479b commit a47c01b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 33 deletions.
69 changes: 38 additions & 31 deletions src/closely/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,37 @@

def solution(
array: np.ndarray,
n: Optional[int] = None,
n: Optional[int] = 5,
metric="euclidean",
quantile: Optional[float] = None,
):
"""Solve the closest pairs problem.
Args:
array (np.ndarray): N x M
n (int, optional): number of closest pairs
n (int, optional): number of closest pairs; either n or quantile must be defined
metric (str): distance metric
quantile (float, optional): between 0 and 1
Returns:
pairs (list of tuples): closest pairs of points
pairs (np.ndarray): closest pairs of points
distances (np.ndarray): distances between pairs
"""
if quantile is not None:
index = get_index_of_quantile(array, quantile)
pairs, distances = closest_k_pairs(array, kth=index, metric=metric)
else:
elif n is not None:
pairs, distances = closest_k_pairs(array, kth=n, metric=metric)
return pairs, distances


def closest_k_pairs(array: np.ndarray, kth: int = 3, metric: str = "euclidean"):
def closest_k_pairs(array: np.ndarray, kth:int, metric: str = "euclidean"):
"""Get closest k-pairs in a matrix.
Args:
array (np.ndarray): n instances x m features matrix
kth (int): k lowest pairs
metric (str): distance metric
Returns:
pairs (list of tuples): coordinates of nearest pairs, ordered
pairs (np.ndarray): coordinates of nearest pairs, ordered (e.g., [[0,2],[3,5],...])
distances (np.ndarray): 1-d array of distances for each pair, sorted
"""
Expand All @@ -53,11 +53,21 @@ def closest_k_pairs(array: np.ndarray, kth: int = 3, metric: str = "euclidean"):

pairs = list(zip(coord1[:kth], coord2[:kth]))
distances = dist_mat[coord1[:kth], coord2[:kth]]
return pairs, distances
indices = np.argsort(distances)
pairs = np.array(pairs)[indices]
return pairs, distances[indices]


def get_index_of_quantile(dist_mat: np.ndarray, quantile: float):
"""Returns index of `quantile` in `dist_mat`."""
"""Returns index of `quantile` in `dist_mat`.
Args:
dist_mat (np.ndarray): square distance matrix
quantile (float): quantile
Returns:
index (int): index of quantile
"""
flat_dist_mat = dist_mat.flatten()
flat_dist_mat.sort()

Expand All @@ -67,13 +77,13 @@ def get_index_of_quantile(dist_mat: np.ndarray, quantile: float):


def seriation(Z, N, cur_index):
"""
input:
- Z is a hierarchical tree (dendrogram)
- N is the number of points given to the clustering process
- cur_index is the position in the tree for the recursive traversal
output:
- order implied by the hierarchical tree Z
"""Order a distance matrix with a hierarchical clustering dendrogram
Args:
Z (np.ndarray): hierarchical tree (dendrogram)
N (int): number of points given to the clustering process
cur_index (int): position in the tree for the recursive traversal
Returns:
order (list of ints): order implied by the hierarchical tree Z
seriation computes the order implied by a hierarchical tree (dendrogram)
"""
Expand All @@ -86,22 +96,19 @@ def seriation(Z, N, cur_index):


def compute_serial_matrix(dist_mat, method="ward"):
"""
input:
- dist_mat is a distance matrix
- method = ["ward","single","average","complete"]
output:
- seriated_dist is the input dist_mat,
but with re-ordered rows and columns
according to the seriation, i.e. the
order implied by the hierarchical tree
- res_order is the order implied by
the hierarchical tree
- res_linkage is the hierarchical tree (dendrogram)
compute_serial_matrix transforms a distance matrix into
a sorted distance matrix according to the order implied
by the hierarchical tree (dendrogram)
"""Transforms a distance matrix into a sorted distance matrix according to
the order implied by the hierarchical tree (dendrogram)
Args:
dist_mat (np.ndarray): distance matrix
method (str): one of "ward","single","average","complete"
Returns:
seriated_dist (np.ndarray): the input dist_mat, but with re-ordered rows and columns
according to the seriation, i.e. the order implied by the hierarchical tree
res_order (np.ndarray): order implied by the hierarchical tree
res_linkage (np.ndarray): hierarchical tree (dendrogram)
"""
try:
from fastcluster import linkage
Expand Down
4 changes: 2 additions & 2 deletions tests/test_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
def test_solution(n_points, n_dim):
X = np.random.random((1000, n_dim))
pairs, distances = closely.solve(X, n=n_points)
assert isinstance(pairs, list)
assert isinstance(pairs, np.ndarray)
assert len(pairs) >= n_points
assert isinstance(distances, np.ndarray)


def test_quantile():
X = np.random.random((1000, 8))
pairs, distances = closely.solve(X, quantile=0.01)
assert isinstance(pairs, list)
assert isinstance(pairs, np.ndarray)
assert isinstance(distances, np.ndarray)


Expand Down

0 comments on commit a47c01b

Please sign in to comment.