Skip to content

Commit

Permalink
Pep8ify code.
Browse files Browse the repository at this point in the history
  • Loading branch information
braniii committed Oct 8, 2024
1 parent 05dbd6d commit 33f8f85
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 111 deletions.
1 change: 1 addition & 0 deletions src/normi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
"""Normalized mutual information"""

__all__ = ['NormalizedMI']

NORMS = {'joint', 'geometric', 'arithmetic', 'min', 'max'}
Expand Down
9 changes: 8 additions & 1 deletion src/normi/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
All rights reserved.
"""

import click
import numpy as np

Expand Down Expand Up @@ -96,7 +97,13 @@
help='Activate verbose mode.',
)
def main(
input_file, output_basename, norm, inv_measure, n_dims, precision, verbose,
input_file,
output_basename,
norm,
inv_measure,
n_dims,
precision,
verbose,
):
# load file
if verbose:
Expand Down
30 changes: 19 additions & 11 deletions src/normi/_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
All rights reserved.
"""

__all__ = ['NormalizedMI'] # noqa: WPS410

import numpy as np
Expand All @@ -32,7 +33,7 @@
PositiveFloat,
PositiveInt,
PositiveMatrix,
ArrayLikePositiveInt
ArrayLikePositiveInt,
)


Expand Down Expand Up @@ -221,7 +222,8 @@ def _reset(self) -> None:

@beartype
def nmi(
self, normalize_method: Optional[NormString] = None,
self,
normalize_method: Optional[NormString] = None,
) -> NormalizedMatrix:
"""Return the normalized mutual information matrix.
Expand Down Expand Up @@ -265,11 +267,13 @@ def nmi(

@beartype
def _kraskov_estimator(
self, X: List[Float2DArray],
self,
X: List[Float2DArray],
) -> Tuple[PositiveMatrix, FloatMatrix, FloatMatrix, FloatMatrix]:
"""Estimate the mutual information and entropies matrices."""
mi: PositiveMatrix = np.empty( # noqa: WPS317
(self._n_features, self._n_features), dtype=self._dtype,
(self._n_features, self._n_features),
dtype=self._dtype,
)
hxy: FloatMatrix = np.empty_like(mi)
hx: FloatMatrix = np.empty_like(mi)
Expand All @@ -286,7 +290,7 @@ def _kraskov_estimator(
hxy[idx_i, idx_i] = 1
hx[idx_i, idx_i] = 1
hy[idx_i, idx_i] = 1
for idx_j, xj in enumerate(X[idx_i + 1:], idx_i + 1):
for idx_j, xj in enumerate(X[idx_i + 1 :], idx_i + 1):
mi_ij, hxy_ij, hx_ij, hy_ij = kraskov_estimator(
xi,
xj,
Expand Down Expand Up @@ -337,9 +341,7 @@ def _scale_nearest_neighbor_distance(
if invariant_measure == 'radius':
return radii / np.mean(radii)
elif invariant_measure == 'volume':
return radii / (
np.mean(radii ** n_dims) ** (1 / n_dims)
)
return radii / (np.mean(radii**n_dims) ** (1 / n_dims))
elif invariant_measure == 'kraskov':
return radii
# This should never be reached
Expand Down Expand Up @@ -399,7 +401,9 @@ def kraskov_estimator(
# Here we rely on NearestNeighbors to select the fastest algorithm.
tree = KDTree(xy)
radii: FloatArray = tree.query(
xy, k=n_neighbors + 1, **kdtree_kwargs,
xy,
k=n_neighbors + 1,
**kdtree_kwargs,
)[0][:, 1:] # neglect self count
# take next smaller radii
radii: FloatArray = np.nextafter(radii[:, -1], 0)
Expand All @@ -415,8 +419,12 @@ def kraskov_estimator(
ny: FloatArray
nx, ny = [
KDTree(z).query_ball_point(
z, r=radii, return_length=True, **kdtree_kwargs,
) - 1 # fix self count
z,
r=radii,
return_length=True,
**kdtree_kwargs,
)
- 1 # fix self count
for z in (x, y)
]

Expand Down
18 changes: 12 additions & 6 deletions src/normi/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
All rights reserved.
"""

import numpy as np
from beartype.typing import List, Union
from beartype.vale import Is, IsAttr, IsEqual

from normi import INVMEASURES, NORMS

try: # for python <= 3.8 use typing_extensions
Expand All @@ -29,23 +31,27 @@ def _get_resolution(x):

def _allclose(x, y) -> bool:
"""Wrapper around np.allclose with dtype dependent atol."""
atol = np.max([
_get_resolution(x),
_get_resolution(y),
# default value of numpy
1e-8,
])
atol = np.max(
[
_get_resolution(x),
_get_resolution(y),
# default value of numpy
1e-8,
]
)
return np.allclose(x, y, atol=atol)


class NDim:
"""Class for creating Validators checking for desired dimensions."""

def __class_getitem__(self, ndim):
return IsAttr['ndim', IsEqual[ndim]]


class DType:
"""Class for creating Validators checking for desired dtype."""

def __class_getitem__(self, dtype):
return Is[lambda arr: np.issubdtype(arr.dtype, dtype)]

Expand Down
4 changes: 2 additions & 2 deletions src/normi/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
All rights reserved.
"""

__all__ = ['savetxt'] # noqa: WPS410

import datetime
Expand Down Expand Up @@ -37,8 +38,7 @@ def _get_rui() -> str:
}

return (
'This file was generated by nmi:\n{args}' +
'\n\n{date}, {user}@{pc}'
'This file was generated by nmi:\n{args}\n\n{date}, {user}@{pc}'
).format(**rui)


Expand Down
Loading

0 comments on commit 33f8f85

Please sign in to comment.