Commit

Merge branch 'main' of github.com:moldyn/normi
braniii committed Oct 8, 2024
2 parents 4143ec1 + 41e2fb3 commit 9e9b3f1
Showing 9 changed files with 79 additions and 32 deletions.
13 changes: 13 additions & 0 deletions .github/dependabot.yml
@@ -0,0 +1,13 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file

version: 2
updates:
# Maintain dependencies for GitHub Actions
- package-ecosystem: "github-actions"
# Workflow files stored in the default location of `.github/workflows`. (You don't need to specify `/.github/workflows` for `directory`. You can use `directory: "/"`.)
directory: "/"
schedule:
interval: "weekly"
8 changes: 4 additions & 4 deletions .github/workflows/codeql.yml
@@ -46,11 +46,11 @@ jobs:

steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4

# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v2
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
@@ -64,7 +64,7 @@ jobs:
# Autobuild attempts to build any compiled languages (C/C++, C#, Go, Java, or Swift).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
uses: github/codeql-action/autobuild@v2
uses: github/codeql-action/autobuild@v3

# ℹ️ Command-line programs to run using the OS shell.
# 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
@@ -77,6 +77,6 @@ jobs:
# ./location_of_script_within_repo/buildscript.sh

- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v2
uses: github/codeql-action/analyze@v3
with:
category: "/language:${{matrix.language}}"
4 changes: 2 additions & 2 deletions .github/workflows/pages.yml
@@ -11,9 +11,9 @@ jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python 3.11
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install dependencies
6 changes: 3 additions & 3 deletions .github/workflows/pytest.yml
@@ -14,9 +14,9 @@ jobs:
env:
PYTHON: ${{ matrix.python-version }}
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
@@ -30,7 +30,7 @@ jobs:
run: |
pytest --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
4 changes: 2 additions & 2 deletions .github/workflows/python-publish.yml
@@ -21,9 +21,9 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v3
uses: actions/setup-python@v5
with:
python-version: '3.x'
- name: Install dependencies
6 changes: 4 additions & 2 deletions README.md
@@ -23,6 +23,8 @@
<img src="https://img.shields.io/pypi/pyversions/normi" /></a>
<a href="https://moldyn.github.io/normi" alt="Docs">
<img src="https://img.shields.io/badge/MkDocs-Documentation-brightgreen" /></a>
<a href="https://doi.org/10.1063/5.0217960" alt="doi">
<img src="https://img.shields.io/badge/doi-10.1063%2F5.0217960-blue" /></a>
<a href="https://arxiv.org/abs/2405.04980" alt="arXiv">
<img src="https://img.shields.io/badge/arXiv-2405.04980-red" /></a>
<a href="https://github.com/moldyn/normi/blob/main/LICENSE" alt="License">
@@ -44,8 +46,8 @@ This software provides an extension to the Kraskov-Estimator to allow normalizin
The method was published in:
> **Accurate estimation of the normalized mutual information of multidimensional data**
> D. Nagel, G. Diez, and G. Stock,
> *arXiv* **2024**
> doi: [10.48550/arXiv.2405.04980](https://doi.org/10.48550/arXiv.2405.04980)
> *J. Chem. Phys.* **2024** 161, 054108
> doi: [10.1063/5.0217960](https://doi.org/10.1063/5.0217960)
If you use this software package, please cite the above-mentioned paper.

56 changes: 37 additions & 19 deletions src/normi/_estimators.py
@@ -32,6 +32,7 @@
PositiveFloat,
PositiveInt,
PositiveMatrix,
ArrayLikePositiveInt
)


@@ -40,7 +41,7 @@ class NormalizedMI(BaseEstimator):
Parameters
----------
n_dims : int, default=1
n_dims : int or list of ints, default=1
Dimensionality of input vectors.
normalize_method : str, default='geometric'
Determines the normalization factor for the mutual information:<br/>
@@ -94,15 +95,15 @@ class NormalizedMI(BaseEstimator):
def __init__(
self,
*,
n_dims: PositiveInt = 1,
n_dims: Union[ArrayLikePositiveInt, PositiveInt] = 1,
normalize_method: NormString = 'geometric',
invariant_measure: InvMeasureString = 'volume',
k: PositiveInt = 5,
n_jobs: Int = -1,
verbose: bool = True,
):
"""Initialize NormalizedMI class."""
self.n_dims: PositiveInt = n_dims
self.n_dims: Union[ArrayLikePositiveInt, PositiveInt] = n_dims
self.normalize_method: NormString = normalize_method
self.invariant_measure: InvMeasureString = invariant_measure
self.k: PositiveInt = k
@@ -131,21 +132,26 @@ def fit(
"""
self._reset()

_check_X(X=X, n_dims=self.n_dims)

# define number of features and samples
n_samples: int
n_cols: int
n_samples, n_cols = X.shape
self._n_samples: int = n_samples
self._n_features: int = n_cols // self.n_dims
if isinstance(self.n_dims, int):
self._n_features: int = n_cols // self.n_dims
else:
self._n_features: int = len(self.n_dims)

# scale input
X = StandardScaler().fit_transform(X)
X = np.split(X, self._n_features, axis=1)
if isinstance(self.n_dims, int):
X = np.split(X, self._n_features, axis=1)
else:
X = np.split(X, np.cumsum(self.n_dims), axis=1)[:-1]

self.mi_: PositiveMatrix
self.mi_: FloatMatrix
self.hxy_: FloatMatrix
self.hx_: FloatMatrix
self.hy_: FloatMatrix
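
The variable-width branch in this hunk passes cumulative column offsets to `np.split`. Because the last offset equals the total number of columns, the final piece is empty, which is why the result is trimmed with `[:-1]`. The grouping can be sketched with the standard library alone. This is only an illustration: `column_groups` is a hypothetical helper that is not part of normi, and it returns column indices instead of array views.

```python
from itertools import accumulate


def column_groups(n_cols, n_dims):
    """Return the column indices that belong to each feature.

    Hypothetical helper for illustration only; normi itself splits the
    scaled array with np.split.
    """
    if isinstance(n_dims, int):
        # Equal-width features: n_cols // n_dims groups of n_dims columns.
        widths = [n_dims] * (n_cols // n_dims)
    else:
        # One width per feature, e.g. [4, 3, 2].
        widths = list(n_dims)

    # Cumulative offsets mark where each feature ends: [4, 3, 2] -> [4, 7, 9].
    ends = list(accumulate(widths))
    starts = [0] + ends[:-1]
    return [list(range(start, end)) for start, end in zip(starts, ends)]


groups = column_groups(9, [4, 3, 2])
```

With nine columns and widths `[4, 3, 2]`, the first feature covers columns 0 to 3, the second columns 4 to 6, and the third columns 7 and 8.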
@@ -436,11 +442,31 @@ def kraskov_estimator(


@beartype
def _check_X(X: Float2DArray, n_dims: PositiveInt):
def _check_X(X: Float2DArray, n_dims: Union[ArrayLikePositiveInt, PositiveInt]):
"""Sanity check of the input to ensure correct format and dimension."""
# parse data
if X.shape[1] < 2 * n_dims:
raise ValueError('At least two variables need to be provided')
_, n_cols = X.shape

if isinstance(n_dims, int):
# When n_dims is a single integer
if n_cols < 2 * n_dims:
raise ValueError('At least two variables need to be provided')

n_features = n_cols // n_dims
if n_cols != n_features * n_dims:
raise ValueError(
'The number of provided columns needs to be a multiple of the '
'specified dimensionality `n_dims`.',
)
else:
# When n_dims is a list of integers
if len(n_dims) < 2:
raise ValueError('At least two variables need to be provided')

if np.sum(n_dims) != n_cols:
raise ValueError(
'The number of provided columns needs to match with the sum of `n_dims`.',
)

stds = np.std(X, axis=0)
invalid_stds = (stds == 0) | (np.isnan(stds))
if np.any(invalid_stds):
@@ -449,11 +475,3 @@ def _check_X(X: Float2DArray, n_dims: PositiveInt):
f'Columns {idxs} have a standard deviation of zero or NaN. '
'These columns cannot be used for estimating the NMI.',
)

_, n_cols = X.shape
n_features = n_cols // n_dims
if n_cols != n_features * n_dims:
raise ValueError(
'The number of provided columns needs to be a multiple of the '
'specified dimensionality `n_dims`.',
)
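
The column-count rules that `_check_X` enforces after this change can be summarized in a small stdlib-only sketch. It assumes a hypothetical function `check_n_dims` (not part of normi) that takes only the number of columns, leaves out the standard-deviation check, and uses shortened error messages.

```python
def check_n_dims(n_cols, n_dims):
    """Validate n_dims against the column count and return the feature count.

    Illustrative only: mirrors the two branches of the diff above, without
    beartype and without the zero/NaN standard-deviation check.
    """
    if isinstance(n_dims, int):
        # Single integer: every feature has the same width.
        if n_cols < 2 * n_dims:
            raise ValueError('At least two variables need to be provided')
        if n_cols % n_dims != 0:
            raise ValueError('Number of columns must be a multiple of n_dims')
        return n_cols // n_dims

    # Sequence of integers: one width per feature.
    widths = list(n_dims)
    if len(widths) < 2:
        raise ValueError('At least two variables need to be provided')
    if sum(widths) != n_cols:
        raise ValueError('Number of columns must match the sum of n_dims')
    return len(widths)


n_features = check_n_dims(9, [4, 3, 2])
```

The cases match the parametrized tests added in this commit: for nine columns, `3` and `[4, 3, 2]` pass, while `[3, 2, 3]`, `[9]`, `5`, and `2` raise `ValueError`.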
3 changes: 3 additions & 0 deletions src/normi/_typing.py
@@ -84,6 +84,9 @@ def __class_getitem__(self, dtype):
DTypeLike = Annotated[type, IsDTypeLike]

# array datatypes
IntNDArray = Annotated[np.ndarray, DType[np.integer]]
ArrayLikeInt = Union[List[int], IntNDArray]
ArrayLikePositiveInt = Annotated[ArrayLikeInt, IsPositive]
FloatNDArray = Annotated[np.ndarray, DType[np.floating]]
ArrayLikeFloat = Union[List[float], FloatNDArray]
FloatArray = Annotated[FloatNDArray, NDim[1]]
11 changes: 11 additions & 0 deletions tests/test__estimators.py
@@ -77,6 +77,11 @@ def test__scale_nearest_neighbor_distance(
@pytest.mark.parametrize('X, n_dims, error', [
(np.random.uniform(size=(10, 9)), 1, None),
(np.random.uniform(size=(10, 9)), 3, None),
(np.random.uniform(size=(10, 9)), np.array([3, 3, 3]), None),
(np.random.uniform(size=(10, 9)), np.array([4, 3, 2]), None),
(np.random.uniform(size=(10, 9)), np.array([3, 2, 3]), ValueError),
(np.random.uniform(size=(10, 9)), np.array([9]), ValueError),
(np.random.uniform(size=(10, 9)), 5, ValueError),
(np.random.uniform(size=(10, 9)), 2, ValueError),
(np.zeros((10, 9)).astype(float), 1, ValueError),
(np.vander((1, 2, 3, 4), 3).astype(float), 1, ValueError),
@@ -152,6 +157,12 @@ def test__reset(normalize_method, X, kwargs):
X1_result('geometric', 'volume'),
None,
),
(
X1(),
{'n_dims': np.array([1, 1]), 'normalize_method': 'geometric', 'invariant_measure': 'volume'},
X1_result('geometric', 'volume'),
None,
),
])
def test_NormalizedMI(X, kwargs, result, error):
# cast radii to float to fulfill beartype typing req.