diff --git a/package/CHANGELOG b/package/CHANGELOG index 7b5ce3b4c8..3cdec6361b 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -15,7 +15,7 @@ The rules for this file: ------------------------------------------------------------------------------- ??/??/?? IAlibay, ChiahsinChu, RMeli, tanishy7777, talagayev, tylerjereddy, - marinegor + marinegor, yuxuanzhuang * 2.9.0 @@ -38,6 +38,11 @@ Enhancements Changes * Changed `fasteners` dependency to `filelock` (Issue #4797, PR #4800) * Codebase is now formatted with black (version `24`) (PR #4886) + * Added _global_slicer to make it possible to retrieve information + of the full trajectory slice in AnalysisBase. (Issue #4891 PR #4892) + * Added _global_frame_index to to keep track of frame iteration + number for the full trajectory slice in `single_frame` in AnalysisBase + (Issue #4891 PR #4892) Deprecations diff --git a/package/MDAnalysis/analysis/base.py b/package/MDAnalysis/analysis/base.py index 675c6d6967..5c6dad6aa9 100644 --- a/package/MDAnalysis/analysis/base.py +++ b/package/MDAnalysis/analysis/base.py @@ -441,8 +441,13 @@ def _setup_frames( each of the workers and gets executed twice: one time in :meth:`_setup_frames` for the whole trajectory, second time in :meth:`_compute` for each of the computation groups. + + .. versionchanged:: 2.9.0 + Add `self._global_slicer` attribute to store the slicer for the + whole trajectory. """ slicer = self._define_run_frames(trajectory, start, stop, step, frames) + self._global_slicer = slicer self._prepare_sliced_trajectory(slicer) def _single_frame(self): @@ -452,6 +457,9 @@ def _single_frame(self): Attributes accessible during your calculations: - ``self._frame_index``: index of the frame in results array + Note that this is not the same as the frame number in the trajectory + - ``self._global_frame_index``: index of the frame in the trajectory + This is useful for parallel runs, where you can't rely on the - ``self._ts`` -- Timestep instance - ``self._sliced_trajectory`` -- trajectory that you're iterating over - ``self.results`` -- :class:`MDAnalysis.analysis.results.Results` instance @@ -537,6 +545,7 @@ def _compute( ) ): self._frame_index = idx # accessed later by subclasses + self._global_frame_index = indexed_frames[idx, 0] self._ts = ts self.frames[idx] = ts.frame self.times[idx] = ts.time diff --git a/package/doc/sphinx/source/documentation_pages/analysis/parallelization.rst b/package/doc/sphinx/source/documentation_pages/analysis/parallelization.rst index 3070614b5a..3c741fd6c4 100644 --- a/package/doc/sphinx/source/documentation_pages/analysis/parallelization.rst +++ b/package/doc/sphinx/source/documentation_pages/analysis/parallelization.rst @@ -118,22 +118,29 @@ For MDAnalysis developers From a developer point of view, there are a few methods that are important in order to understand how parallelization is implemented: -#. :meth:`MDAnalysis.analysis.base.AnalysisBase._define_run_frames` +#. :meth:`MDAnalysis.analysis.base.AnalysisBase._setup_frames` #. :meth:`MDAnalysis.analysis.base.AnalysisBase._prepare_sliced_trajectory` #. :meth:`MDAnalysis.analysis.base.AnalysisBase._configure_backend` #. :meth:`MDAnalysis.analysis.base.AnalysisBase._setup_computation_groups` #. :meth:`MDAnalysis.analysis.base.AnalysisBase._compute` #. :meth:`MDAnalysis.analysis.base.AnalysisBase._get_aggregator` -The first two methods share the functionality of :meth:`_setup_frames`. -:meth:`_define_run_frames` is run once during analysis, as it checks that input -parameters `start`, `stop`, `step` or `frames` are consistent with the given -trajectory and prepares the ``slicer`` object that defines the iteration -pattern through the trajectory. :meth:`_prepare_sliced_trajectory` assigns to +:meth:`_setup_frames` is run once during analysis :attr:`run()`, as it checks that input +parameters :attr:`start`, :attr:`stop`, :attr:`step` or :attr:`frames` are consistent with the given +trajectory and prepares the :attr:`slicer` object that defines the iteration +pattern through the trajectory with :meth:`_define_run_frames`. +The attribute :attr:`self._global_slicer` is assigned based on the `slicer`. +Users can later access the full sliced trajectory being analyzed via +:attr:`self._trajectory[self._global_slicer]`. + +:meth:`_prepare_sliced_trajectory` assigns to the :attr:`self._sliced_trajectory` attribute, computes the number of frames in it, and fills the :attr:`self.frames` and :attr:`self.times` arrays. In case the computation will be later split between other processes, this method will -be called again on each of the computation groups. +be called again on each of the computation groups. In parallel analysis, +:attr:`self._sliced_trajectory` represents a split of the original sliced +trajectory, and :attr:`self.n_frames` is the number of frames in each split +computation group (not the total number of frames in the sliced trajectory). The method :meth:`_configure_backend` performs basic health checks for a given analysis class -- namely, it compares a given backend (if it's a :class:`str` @@ -155,7 +162,13 @@ analysis get initialized with the :meth:`_prepare` method. Then the function iterates over :attr:`self._sliced_trajectory`, assigning :attr:`self._frame_index` and :attr:`self._ts` as frame index (within a computation group) and timestamp, and also setting respective -:attr:`self.frames` and :attr:`self.times` array values. +:attr:`self.frames` and :attr:`self.times` array values. Additionally, +:attr:`self._global_frame_index` is assigned the global frame index +within the full sliced trajectory (:attr:`self._trajectory[self._global_slicer]`). +This global frame index is particularly useful for analyses requiring it, such as +:class:`MDAnalysis.analysis.diffusionmap.DistanceMatrix` that needs to know the +frame index in the full trajectory. +See :ref:`retrieving-correct-frame-index` for more details. After :meth:`_compute` has finished, the main analysis instance calls the :meth:`_get_aggregator` method, which merges the :attr:`self.results` @@ -357,6 +370,82 @@ In this way, you will override the check for supported backends. with a supported backend. When reporting *always mention if you used* ``unsupported_backend=True``. +.. _retrieving-correct-frame-index: +Retrieving correct frame index in parallel analysis +=================================================== + +To retrieve the correct frame index during parallel analysis, use the +:attr:`self._global_frame_index` attribute. This attribute represents the global +frame index within the full sliced trajectory +(:attr:`self._trajectory[self._global_slicer]`). + +For an example illustrating when to use :attr:`_frame_index` versus +:attr:`_global_frame_index` and :attr:`self._global_slicer`, +see the following code snippet: + +.. code-block:: python + + from MDAnalysis.analysis.base import AnalysisBase + from MDAnalysis.analysis.results import ResultsGroup + + class MyAnalysis(AnalysisBase): + _analysis_algorithm_is_parallelizable = True + + @classmethod + def get_supported_backends(cls): + """Define the supported backends for the analysis.""" + return ('serial', 'multiprocessing', 'dask') + + def _prepare(self): + """Initialize result attributes and compute global frame count.""" + self.results.frame_index = [] + self.results.global_frame_index = [] + self.results.n_frames = [] + self.results.global_n_frames = [] + self.global_n_frames = len(self._trajectory[self._global_slicer]) + + def _single_frame(self): + """Process a single frame during the analysis.""" + frame_index = self._frame_index + global_frame_index = self._global_frame_index + + # Append results for the current frame + self.results.frame_index.append(frame_index) + self.results.global_frame_index.append(global_frame_index) + self.results.n_frames.append(self.n_frames) + self.results.global_n_frames.append(self.global_n_frames) + + def _get_aggregator(self): + """Return an aggregator to combine results from multiple workers.""" + return ResultsGroup( + lookup={ + 'frame_index': ResultsGroup.flatten_sequence, + 'global_frame_index': ResultsGroup.flatten_sequence, + 'n_frames': ResultsGroup.flatten_sequence, + 'global_n_frames': ResultsGroup.flatten_sequence, + } + ) + + # Example usage: serial analysis + ana = MyAnalysis(u.trajectory) + ana.run(step=2) + print(ana.results) + # Output: + # {'frame_index': [0, 1, 2, 3, 4], + # 'global_frame_index': [0, 1, 2, 3, 4], + # 'n_frames': [5, 5, 5, 5, 5], + # 'global_n_frames': [5, 5, 5, 5, 5]} + + # Example usage: parallel analysis + ana = MyAnalysis(u.trajectory) + ana.run(step=2, backend='dask', n_workers=2) + print(ana.results) + # Output: + # {'frame_index': [0, 1, 2, 0, 1], + # 'global_frame_index': [0, 1, 2, 3, 4], + # 'n_frames': [3, 3, 3, 2, 2], + # 'global_n_frames': [5, 5, 5, 5, 5]} + .. rubric:: References .. footbibliography:: diff --git a/testsuite/MDAnalysisTests/analysis/test_base.py b/testsuite/MDAnalysisTests/analysis/test_base.py index e369c4c602..9cc59bb21e 100644 --- a/testsuite/MDAnalysisTests/analysis/test_base.py +++ b/testsuite/MDAnalysisTests/analysis/test_base.py @@ -48,16 +48,36 @@ def __init__(self, reader, **kwargs): def _prepare(self): self.results.found_frames = [] + self.results.frame_index = [] + self.results.global_frame_index = [] + self.results.n_frames = [] + self.results.global_n_frames = [] + + # self.n_frames is defined elsewhere + self.global_n_frames = len(self._trajectory[self._global_slicer]) def _single_frame(self): + frame_index = self._frame_index + global_frame_index = self._global_frame_index + self.results.found_frames.append(self._ts.frame) + self.results.frame_index.append(frame_index) + self.results.global_frame_index.append(global_frame_index) + self.results.n_frames.append(self.n_frames) + self.results.global_n_frames.append(self.global_n_frames) def _conclude(self): self.found_frames = list(self.results.found_frames) def _get_aggregator(self): return base.ResultsGroup( - {"found_frames": base.ResultsGroup.ndarray_hstack} + { + "found_frames": base.ResultsGroup.ndarray_hstack, + "frame_index": base.ResultsGroup.ndarray_hstack, + "global_frame_index": base.ResultsGroup.ndarray_hstack, + "n_frames": base.ResultsGroup.ndarray_hstack, + "global_n_frames": base.ResultsGroup.ndarray_hstack, + } ) @@ -450,12 +470,17 @@ def test_frames_times(client_FrameAnalysis): start=1, stop=8, step=2, **client_FrameAnalysis ) frames = np.array([1, 3, 5, 7]) - assert an.n_frames == len(frames) + n_frames = len(frames) + frame_indices = np.arange(n_frames) + + assert an.n_frames == n_frames assert_equal(an.found_frames, frames) assert_equal(an.frames, frames, err_msg=FRAMES_ERR) assert_allclose( an.times, frames * 100, rtol=0, atol=1.5e-4, err_msg=TIMES_ERR ) + assert_equal(an.results.global_frame_index, frame_indices) + assert_equal(an.results.global_n_frames, [n_frames] * n_frames) def test_verbose(u):