Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

'MDAnalysis.analysis.align' parallelization #4738

Draft
wants to merge 21 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions package/MDAnalysis/analysis/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@
from MDAnalysis.lib.log import ProgressBar
from ..due import due, Doi

from .base import AnalysisBase
from .base import AnalysisBase, ResultsGroup

logger = logging.getLogger('MDAnalysis.analysis.align')

Expand Down Expand Up @@ -678,6 +678,12 @@ class AlignTraj(AnalysisBase):

"""

_analysis_algorithm_is_parallelizable = True

@classmethod
def get_supported_backends(cls):
return ("serial", "dask")

def __init__(self, mobile, reference, select='all', filename=None,
prefix='rmsfit_', weights=None,
tol_mass=0.1, match_atoms=True, strict=False, force=True, in_memory=False,
Expand Down Expand Up @@ -854,7 +860,8 @@ def _single_frame(self):
self._writer.write(mobile_atoms)

def _conclude(self):
self._writer.close()
if self._writer:
self._writer.close()
if not self._verbose:
logging.disable(logging.NOTSET)

Expand All @@ -866,6 +873,12 @@ def rmsd(self):
warnings.warn(wmsg, DeprecationWarning)
return self.results.rmsd

def _get_aggregator(self):
return ResultsGroup(
lookup={
"rmsd": ResultsGroup.ndarray_hstack,
}
)

class AverageStructure(AnalysisBase):
"""RMS-align trajectory to a reference structure using a selection,
Expand Down Expand Up @@ -896,6 +909,12 @@ class AverageStructure(AnalysisBase):

"""

_analysis_algorithm_is_parallelizable = True

@classmethod
def get_supported_backends(cls):
return ("serial", "multiprocessing", "dask")

def __init__(self, mobile, reference=None, select='all', filename=None,
weights=None,
tol_mass=0.1, match_atoms=True, strict=False, force=True, in_memory=False,
Expand Down Expand Up @@ -1089,6 +1108,15 @@ def _conclude(self):
if not self._verbose:
logging.disable(logging.NOTSET)

def _get_aggregator(self):
return ResultsGroup(
lookup={
"universe": ResultsGroup.ndarray_vstack,
"positions": ResultsGroup.ndarray_vstack,
"rmsd": ResultsGroup.ndarray_vstack,
}
)

@property
def universe(self):
wmsg = ("The `universe` attribute was deprecated in MDAnalysis 2.0.0 "
Expand Down
13 changes: 13 additions & 0 deletions testsuite/MDAnalysisTests/analysis/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from MDAnalysis.analysis.hydrogenbonds.hbond_analysis import (
HydrogenBondAnalysis,
)
from MDAnalysis.analysis.align import AverageStructure, AlignTraj
from MDAnalysis.analysis.nucleicacids import NucPairDist
from MDAnalysis.analysis.contacts import Contacts
from MDAnalysis.analysis.density import DensityAnalysis
Expand Down Expand Up @@ -154,6 +155,18 @@ def client_HydrogenBondAnalysis(request):
return request.param


# MDAnalysis.analysis.align

@pytest.fixture(scope="module", params=params_for_cls(AverageStructure))
def client_AverageStructure(request):
return request.param


@pytest.fixture(scope="module", params=params_for_cls(AlignTraj))
def client_AlignTraj(request):
return request.param


# MDAnalysis.analysis.nucleicacids


Expand Down
18 changes: 9 additions & 9 deletions testsuite/MDAnalysisTests/analysis/test_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,11 @@ def test_AlignTraj_outfile_default(self, universe, reference, tmpdir):
x._writer.close()

def test_AlignTraj_outfile_default_exists(
self, universe, reference, tmpdir
self, universe, reference, tmpdir, client_AlignTraj
):
reference.trajectory[-1]
outfile = str(tmpdir.join("align_test.dcd"))
align.AlignTraj(universe, reference, filename=outfile).run()
align.AlignTraj(universe, reference, filename=outfile).run(**client_AlignTraj)
fitted = mda.Universe(PSF, outfile)

# ensure default file exists
Expand All @@ -324,13 +324,13 @@ def test_AlignTraj_outfile_default_exists(
with pytest.raises(IOError):
align.AlignTraj(fitted, reference, force=False)

def test_AlignTraj_step_works(self, universe, reference, tmpdir):
def test_AlignTraj_step_works(self, universe, reference, tmpdir, client_AlignTraj):
reference.trajectory[-1]
outfile = str(tmpdir.join("align_test.dcd"))
# this shouldn't throw an exception
align.AlignTraj(universe, reference, filename=outfile).run(step=10)

def test_AlignTraj_deprecated_attribute(self, universe, reference, tmpdir):
def test_AlignTraj_deprecated_attribute(self, universe, reference, tmpdir, client_AlignTraj):
reference.trajectory[-1]
outfile = str(tmpdir.join("align_test.dcd"))
x = align.AlignTraj(universe, reference, filename=outfile).run(stop=2)
Expand All @@ -339,7 +339,7 @@ def test_AlignTraj_deprecated_attribute(self, universe, reference, tmpdir):
with pytest.warns(DeprecationWarning, match=wmsg):
assert_equal(x.rmsd, x.results.rmsd)

def test_AlignTraj(self, universe, reference, tmpdir):
def test_AlignTraj(self, universe, reference, tmpdir, client_AlignTraj):
reference.trajectory[-1]
outfile = str(tmpdir.join("align_test.dcd"))
x = align.AlignTraj(universe, reference, filename=outfile).run()
Expand All @@ -354,7 +354,7 @@ def test_AlignTraj(self, universe, reference, tmpdir):
self._assert_rmsd(reference, fitted, 0, 6.929083044751061)
self._assert_rmsd(reference, fitted, -1, 0.0)

def test_AlignTraj_weighted(self, universe, reference, tmpdir):
def test_AlignTraj_weighted(self, universe, reference, tmpdir, client_AlignTraj):
outfile = str(tmpdir.join("align_test.dcd"))
x = align.AlignTraj(
universe, reference, filename=outfile, weights="mass"
Expand All @@ -374,7 +374,7 @@ def test_AlignTraj_weighted(self, universe, reference, tmpdir):
weights=universe.atoms.masses,
)

def test_AlignTraj_custom_weights(self, universe, reference, tmpdir):
def test_AlignTraj_custom_weights(self, universe, reference, tmpdir, client_AlignTraj):
weights = np.zeros(universe.atoms.n_atoms)
ca = universe.select_atoms("name CA")
weights[ca.indices] = 1
Expand Down Expand Up @@ -507,7 +507,7 @@ def test_alignto_partial_universe(self, universe, reference):
)


def _get_aligned_average_positions(ref_files, ref, select="all", **kwargs):
def _get_aligned_average_positions(ref_files, ref, select="all", **kwargs, ):
u = mda.Universe(*ref_files, in_memory=True)
prealigner = align.AlignTraj(u, ref, select=select, **kwargs).run()
ag = u.select_atoms(select)
Expand All @@ -530,7 +530,7 @@ def reference(self):

def test_average_structure_deprecated_attrs(self, universe, reference):
# Issue #3278 - remove in MDAnalysis 3.0.0
avg = align.AverageStructure(universe, reference).run(stop=2)
avg = align.AverageStructure(universe, reference).run(stop=2,)

wmsg = "The `universe` attribute was deprecated in MDAnalysis 2.0.0"
with pytest.warns(DeprecationWarning, match=wmsg):
Expand Down
Loading