Preserve Full Trajectory Information in AnalysisBase #4892

Open
wants to merge 8 commits into
base: develop
7 changes: 6 additions & 1 deletion package/CHANGELOG
Expand Up @@ -15,7 +15,7 @@ The rules for this file:

-------------------------------------------------------------------------------
??/??/?? IAlibay, ChiahsinChu, RMeli, tanishy7777, talagayev, tylerjereddy,
marinegor
marinegor, yuxuanzhuang

* 2.9.0

Expand All @@ -38,6 +38,11 @@ Enhancements
Changes
* Changed `fasteners` dependency to `filelock` (Issue #4797, PR #4800)
* Codebase is now formatted with black (version `24`) (PR #4886)
* Added _global_slicer to make it possible to retrieve information
of the full trajectory slice in AnalysisBase. (Issue #4891, PR #4892)
* Added _global_frame_index to keep track of the frame iteration
number for the full trajectory slice in `_single_frame` in AnalysisBase
(Issue #4891, PR #4892)

Deprecations

Expand Down
9 changes: 9 additions & 0 deletions package/MDAnalysis/analysis/base.py
Expand Up @@ -441,8 +441,13 @@ def _setup_frames(
each of the workers and gets executed twice: one time in
:meth:`_setup_frames` for the whole trajectory, second time in
:meth:`_compute` for each of the computation groups.

.. versionchanged:: 2.9.0
Add `self._global_slicer` attribute to store the slicer for the
whole trajectory.
"""
slicer = self._define_run_frames(trajectory, start, stop, step, frames)
self._global_slicer = slicer
self._prepare_sliced_trajectory(slicer)

def _single_frame(self):
Expand All @@ -452,6 +457,9 @@ def _single_frame(self):
Attributes accessible during your calculations:

- ``self._frame_index``: index of the frame in the results array.
Note that this is not the same as the frame number in the trajectory
- ``self._global_frame_index``: index of the frame in the full sliced
trajectory. This is useful for parallel runs, where
``self._frame_index`` only counts frames within a computation group
- ``self._ts`` -- Timestep instance
- ``self._sliced_trajectory`` -- trajectory that you're iterating over
- ``self.results`` -- :class:`MDAnalysis.analysis.results.Results` instance
Expand Down Expand Up @@ -537,6 +545,7 @@ def _compute(
)
):
self._frame_index = idx # accessed later by subclasses
self._global_frame_index = indexed_frames[idx, 0]
self._ts = ts
self.frames[idx] = ts.frame
self.times[idx] = ts.time
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,22 +118,29 @@ For MDAnalysis developers
From a developer point of view, there are a few methods that are important in
order to understand how parallelization is implemented:

#. :meth:`MDAnalysis.analysis.base.AnalysisBase._define_run_frames`
#. :meth:`MDAnalysis.analysis.base.AnalysisBase._setup_frames`
#. :meth:`MDAnalysis.analysis.base.AnalysisBase._prepare_sliced_trajectory`
#. :meth:`MDAnalysis.analysis.base.AnalysisBase._configure_backend`
#. :meth:`MDAnalysis.analysis.base.AnalysisBase._setup_computation_groups`
#. :meth:`MDAnalysis.analysis.base.AnalysisBase._compute`
#. :meth:`MDAnalysis.analysis.base.AnalysisBase._get_aggregator`

The first two methods share the functionality of :meth:`_setup_frames`.
:meth:`_define_run_frames` is run once during analysis, as it checks that input
parameters `start`, `stop`, `step` or `frames` are consistent with the given
trajectory and prepares the ``slicer`` object that defines the iteration
pattern through the trajectory. :meth:`_prepare_sliced_trajectory` assigns to
:meth:`_setup_frames` is run once per analysis, during :meth:`run`. It calls
:meth:`_define_run_frames`, which checks that the input parameters
:attr:`start`, :attr:`stop`, :attr:`step` or :attr:`frames` are consistent with
the given trajectory and prepares the :attr:`slicer` object that defines the
iteration pattern through the trajectory. The :attr:`slicer` is then stored as
:attr:`self._global_slicer`, so the full sliced trajectory being analyzed can
later be accessed via :attr:`self._trajectory[self._global_slicer]`.
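
As a minimal sketch of this idea, the snippet below uses a plain Python list as a stand-in for a trajectory. The names ``trajectory`` and ``global_slicer`` are illustrative only and are not MDAnalysis objects. It shows that applying one stored slice to the full frame sequence recovers every frame covered by the run, however the frames are split up later.

```python
# Stand-in for a 10-frame trajectory (illustrative, not an MDAnalysis reader).
trajectory = list(range(10))

# Hypothetical equivalent of run(start=1, stop=8, step=2).
global_slicer = slice(1, 8, 2)

# Applying the stored slicer yields every frame covered by the whole run.
full_selection = trajectory[global_slicer]
global_n_frames = len(full_selection)

print(full_selection)
print(global_n_frames)
```

The same pattern appears in the example class added by this change, where ``len(self._trajectory[self._global_slicer])`` gives the total frame count.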

:meth:`_prepare_sliced_trajectory` assigns to
the :attr:`self._sliced_trajectory` attribute, computes the number of frames in
it, and fills the :attr:`self.frames` and :attr:`self.times` arrays. In case
the computation will be later split between other processes, this method will
be called again on each of the computation groups.
be called again on each of the computation groups. In parallel analysis,
:attr:`self._sliced_trajectory` represents a split of the original sliced
trajectory, and :attr:`self.n_frames` is the number of frames in each split
computation group (not the total number of frames in the sliced trajectory).
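
The difference between per-group and global counts can be sketched without MDAnalysis. The helper ``split_indexed_frames`` below is hypothetical, and it assumes contiguous groups that are as even as possible; the actual splitting logic may differ.

```python
def split_indexed_frames(frame_numbers, n_parts):
    """Split (global_index, frame_number) pairs into contiguous groups.

    Assumption: groups are as even as possible and earlier groups take
    the extra frames. This is an illustration, not the library's code.
    """
    indexed = list(enumerate(frame_numbers))
    base, extra = divmod(len(indexed), n_parts)
    groups, start = [], 0
    for part in range(n_parts):
        size = base + (1 if part < extra else 0)
        groups.append(indexed[start:start + size])
        start += size
    return groups


# Frames selected by a hypothetical run(step=2) on a 10-frame trajectory.
sliced_frames = list(range(0, 10, 2))
groups = split_indexed_frames(sliced_frames, n_parts=2)

for group in groups:
    n_frames = len(group)  # per-group count, analogous to self.n_frames
    for local_index, (global_index, frame) in enumerate(group):
        print(local_index, global_index, frame, n_frames)
```

Here the local index restarts at zero in every group, while the global index keeps counting across groups.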

The method :meth:`_configure_backend` performs basic health checks for a given
analysis class -- namely, it compares a given backend (if it's a :class:`str`
Expand All @@ -155,7 +162,13 @@ analysis get initialized with the :meth:`_prepare` method. Then the function
iterates over :attr:`self._sliced_trajectory`, assigning
:attr:`self._frame_index` and :attr:`self._ts` as frame index (within a
computation group) and timestep, and also setting respective
:attr:`self.frames` and :attr:`self.times` array values.
:attr:`self.frames` and :attr:`self.times` array values. Additionally,
:attr:`self._global_frame_index` is assigned the global frame index
within the full sliced trajectory (:attr:`self._trajectory[self._global_slicer]`).
This global frame index is needed by analyses that relate a frame to its
position in the whole run, such as
:class:`MDAnalysis.analysis.diffusionmap.DistanceMatrix`, which needs to know
the frame index in the full trajectory.
See :ref:`retrieving-correct-frame-index` for more details.
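
To see why a pairwise analysis needs the global index, consider the toy sketch below. It is not the ``DistanceMatrix`` implementation: ``rows_for_group`` and the scalar ``values`` are hypothetical stand-ins. Each group computes rows of a pairwise matrix and stores them under the global index, so the rows can be reassembled in the right order.

```python
def rows_for_group(values, group):
    """Return {global_index: row} for one computation group.

    Each row holds |value - other| against every frame in the full run,
    so it must be stored under the frame's global index.
    """
    rows = {}
    for global_index, value in group:
        rows[global_index] = [abs(value - other) for other in values]
    return rows


# Toy data: one scalar per frame of the full sliced trajectory.
values = [0.0, 1.0, 3.0, 6.0, 10.0]
indexed = list(enumerate(values))

# Two hypothetical computation groups, each seeing only part of the run.
groups = [indexed[:3], indexed[3:]]

merged = {}
for group in groups:
    merged.update(rows_for_group(values, group))

# Assemble the full matrix in global frame order.
matrix = [merged[i] for i in range(len(values))]
print(matrix[3])
```

Keying the rows by a per-group index instead would make the second group overwrite rows 0 and 1 of the first.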

After :meth:`_compute` has finished, the main analysis instance calls the
:meth:`_get_aggregator` method, which merges the :attr:`self.results`
Expand Down Expand Up @@ -357,6 +370,82 @@ In this way, you will override the check for supported backends.
with a supported backend. When reporting *always mention if you used*
``unsupported_backend=True``.

.. _retrieving-correct-frame-index:

Retrieving correct frame index in parallel analysis
===================================================

To retrieve the correct frame index during parallel analysis, use the
:attr:`self._global_frame_index` attribute. This attribute represents the global
frame index within the full sliced trajectory
(:attr:`self._trajectory[self._global_slicer]`).

The following code snippet illustrates when to use :attr:`self._frame_index`
versus :attr:`self._global_frame_index` and :attr:`self._global_slicer`:

.. code-block:: python

    from MDAnalysis.analysis.base import AnalysisBase
    from MDAnalysis.analysis.results import ResultsGroup

    class MyAnalysis(AnalysisBase):
        _analysis_algorithm_is_parallelizable = True

        @classmethod
        def get_supported_backends(cls):
            """Define the supported backends for the analysis."""
            return ('serial', 'multiprocessing', 'dask')

        def _prepare(self):
            """Initialize result attributes and compute global frame count."""
            self.results.frame_index = []
            self.results.global_frame_index = []
            self.results.n_frames = []
            self.results.global_n_frames = []
            self.global_n_frames = len(self._trajectory[self._global_slicer])

        def _single_frame(self):
            """Process a single frame during the analysis."""
            frame_index = self._frame_index
            global_frame_index = self._global_frame_index

            # Append results for the current frame
            self.results.frame_index.append(frame_index)
            self.results.global_frame_index.append(global_frame_index)
            self.results.n_frames.append(self.n_frames)
            self.results.global_n_frames.append(self.global_n_frames)

        def _get_aggregator(self):
            """Return an aggregator to combine results from multiple workers."""
            return ResultsGroup(
                lookup={
                    'frame_index': ResultsGroup.flatten_sequence,
                    'global_frame_index': ResultsGroup.flatten_sequence,
                    'n_frames': ResultsGroup.flatten_sequence,
                    'global_n_frames': ResultsGroup.flatten_sequence,
                }
            )

    # Example usage: serial analysis
    ana = MyAnalysis(u.trajectory)
    ana.run(step=2)
    print(ana.results)
    # Output:
    # {'frame_index': [0, 1, 2, 3, 4],
    # 'global_frame_index': [0, 1, 2, 3, 4],
    # 'n_frames': [5, 5, 5, 5, 5],
    # 'global_n_frames': [5, 5, 5, 5, 5]}

    # Example usage: parallel analysis
    ana = MyAnalysis(u.trajectory)
    ana.run(step=2, backend='dask', n_workers=2)
    print(ana.results)
    # Output:
    # {'frame_index': [0, 1, 2, 0, 1],
    # 'global_frame_index': [0, 1, 2, 3, 4],
    # 'n_frames': [3, 3, 3, 2, 2],
    # 'global_n_frames': [5, 5, 5, 5, 5]}


.. rubric:: References
.. footbibliography::
Expand Down
29 changes: 27 additions & 2 deletions testsuite/MDAnalysisTests/analysis/test_base.py
Expand Up @@ -48,16 +48,36 @@ def __init__(self, reader, **kwargs):

def _prepare(self):
self.results.found_frames = []
self.results.frame_index = []
self.results.global_frame_index = []
self.results.n_frames = []
self.results.global_n_frames = []

# self.n_frames is defined elsewhere
self.global_n_frames = len(self._trajectory[self._global_slicer])

def _single_frame(self):
frame_index = self._frame_index
global_frame_index = self._global_frame_index

self.results.found_frames.append(self._ts.frame)
self.results.frame_index.append(frame_index)
self.results.global_frame_index.append(global_frame_index)
self.results.n_frames.append(self.n_frames)
self.results.global_n_frames.append(self.global_n_frames)

def _conclude(self):
self.found_frames = list(self.results.found_frames)

def _get_aggregator(self):
return base.ResultsGroup(
{"found_frames": base.ResultsGroup.ndarray_hstack}
{
"found_frames": base.ResultsGroup.ndarray_hstack,
"frame_index": base.ResultsGroup.ndarray_hstack,
"global_frame_index": base.ResultsGroup.ndarray_hstack,
"n_frames": base.ResultsGroup.ndarray_hstack,
"global_n_frames": base.ResultsGroup.ndarray_hstack,
}
)


Expand Down Expand Up @@ -450,12 +470,17 @@ def test_frames_times(client_FrameAnalysis):
start=1, stop=8, step=2, **client_FrameAnalysis
)
frames = np.array([1, 3, 5, 7])
assert an.n_frames == len(frames)
n_frames = len(frames)
frame_indices = np.arange(n_frames)

assert an.n_frames == n_frames
assert_equal(an.found_frames, frames)
assert_equal(an.frames, frames, err_msg=FRAMES_ERR)
assert_allclose(
an.times, frames * 100, rtol=0, atol=1.5e-4, err_msg=TIMES_ERR
)
assert_equal(an.results.global_frame_index, frame_indices)
assert_equal(an.results.global_n_frames, [n_frames] * n_frames)


def test_verbose(u):
Expand Down