From cc699df7e33781c4fd9353f1884c05966d204c8f Mon Sep 17 00:00:00 2001
From: mrava87
Date: Wed, 23 Oct 2024 11:39:39 +0300
Subject: [PATCH] feat: added mask to DistributedArray

---
 pylops_mpi/DistributedArray.py         | 82 ++++++++++++++++++++++----
 pylops_mpi/basicoperators/BlockDiag.py | 12 +++-
 pylops_mpi/waveeqprocessing/MDC.py     | 10 ++--
 3 files changed, 86 insertions(+), 18 deletions(-)

diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py
index 6e5a471..7cbb74b 100644
--- a/pylops_mpi/DistributedArray.py
+++ b/pylops_mpi/DistributedArray.py
@@ -5,6 +5,7 @@
 from enum import Enum
 
 from pylops.utils import DTypeLike, NDArray
+from pylops.utils._internal import _value_or_sized_to_tuple
 from pylops.utils.backend import get_module, get_array_module, get_module_name
 
 
@@ -78,7 +79,10 @@ class DistributedArray:
     axis : :obj:`int`, optional
         Axis along which distribution occurs. Defaults to ``0``.
     local_shapes : :obj:`list`, optional
-        List of tuples representing local shapes at each rank.
+        List of tuples or integers representing local shapes at each rank.
+    mask : :obj:`list`, optional
+        Mask defining subsets of ranks to consider when performing 'global'
+        operations on the distributed array such as dot product or norm.
     engine : :obj:`str`, optional
         Engine used to store array (``numpy`` or ``cupy``)
     dtype : :obj:`str`, optional
@@ -88,7 +92,8 @@ def __init__(self, global_shape: Union[Tuple, Integral],
                  base_comm: Optional[MPI.Comm] = MPI.COMM_WORLD,
                  partition: Partition = Partition.SCATTER, axis: int = 0,
-                 local_shapes: Optional[List[Tuple]] = None,
+                 local_shapes: Optional[List[Union[Tuple, Integral]]] = None,
+                 mask: Optional[List[Integral]] = None,
                  engine: Optional[str] = "numpy",
                  dtype: Optional[DTypeLike] = np.float64):
         if isinstance(global_shape, Integral):
@@ -100,10 +105,14 @@ def __init__(self, global_shape: Union[Tuple, Integral],
             raise ValueError(f"Should be either {Partition.BROADCAST} "
                              f"or {Partition.SCATTER}")
         self.dtype = dtype
-        self._global_shape = global_shape
+        self._global_shape = _value_or_sized_to_tuple(global_shape)
         self._base_comm = base_comm
         self._partition = partition
         self._axis = axis
+        self._mask = mask
+        self._sub_comm = base_comm if mask is None else base_comm.Split(color=mask[base_comm.rank], key=base_comm.rank)
+
+        local_shapes = local_shapes if local_shapes is None else [_value_or_sized_to_tuple(local_shape) for local_shape in local_shapes]
         self._check_local_shapes(local_shapes)
         self._local_shape = local_shapes[base_comm.rank] if local_shapes else local_split(global_shape, base_comm,
                                                                                           partition, axis)
@@ -165,6 +174,16 @@ def local_shape(self):
         """
         return self._local_shape
 
+    @property
+    def mask(self):
+        """Mask of the Distributed array
+
+        Returns
+        -------
+        mask : :obj:`list`
+        """
+        return self._mask
+
     @property
     def engine(self):
         """Engine of the Distributed array
@@ -246,6 +265,16 @@ def local_shapes(self):
         """
         return self.base_comm.allgather(self.local_shape)
 
+    @property
+    def sub_comm(self):
+        """MPI Sub-Communicator
+
+        Returns
+        -------
+        sub_comm : :obj:`MPI.Comm`
+        """
+        return self._sub_comm
+
     def asarray(self):
         """Global view of the array
 
@@ -269,7 +298,8 @@ def to_dist(cls, x: NDArray,
                 base_comm: MPI.Comm = MPI.COMM_WORLD,
                 partition: Partition = Partition.SCATTER,
                 axis: int = 0,
-                local_shapes: Optional[List[Tuple]] = None):
+                local_shapes: Optional[List[Tuple]] = None,
+                mask: Optional[List[Integral]] = None):
         """Convert A Global Array to a Distributed Array
 
         Parameters
         ----------
         x : :obj:`numpy.ndarray`
@@ -284,6 +314,9 @@ def to_dist(cls, x: NDArray,
             Axis of Distribution
         local_shapes : :obj:`list`, optional
             Local Shapes at each rank.
+        mask : :obj:`list`, optional
+            Mask defining subsets of ranks to consider when performing 'global'
+            operations on the distributed array such as dot product or norm.
 
         Returns
         ----------
@@ -295,6 +328,7 @@ def to_dist(cls, x: NDArray,
                                       partition=partition,
                                       axis=axis,
                                       local_shapes=local_shapes,
+                                      mask=mask,
                                       engine=get_module_name(get_array_module(x)),
                                       dtype=x.dtype)
         if partition == Partition.BROADCAST:
@@ -336,6 +370,12 @@ def _check_partition_shape(self, dist_array):
             raise ValueError(f"Local Array Shape Mismatch - "
                              f"{self.local_shape} != {dist_array.local_shape}")
 
+    def _check_mask(self, dist_array):
+        """Check mask of the Array
+        """
+        if not np.array_equal(self.mask, dist_array.mask):
+            raise ValueError("Masks of both arrays must be the same")
+
     def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         """MPI Allreduce operation
         """
@@ -345,12 +385,22 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         self.base_comm.Allreduce(send_buf, recv_buf, op)
         return recv_buf
 
+    def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
+        """MPI Allreduce operation with subcommunicator
+        """
+        if recv_buf is None:
+            return self.sub_comm.allreduce(send_buf, op)
+        # For MIN and MAX which require recv_buf
+        self.sub_comm.Allreduce(send_buf, recv_buf, op)
+        return recv_buf
+
     def __neg__(self):
         arr = DistributedArray(global_shape=self.global_shape,
                                base_comm=self.base_comm,
                                partition=self.partition,
                                axis=self.axis,
                                local_shapes=self.local_shapes,
+                               mask=self.mask,
                                engine=self.engine,
                                dtype=self.dtype)
         arr[:] = -self.local_array
@@ -378,11 +428,13 @@ def add(self, dist_array):
         """Distributed Addition of arrays
         """
         self._check_partition_shape(dist_array)
+        self._check_mask(dist_array)
         SumArray = DistributedArray(global_shape=self.global_shape,
                                     base_comm=self.base_comm,
                                     dtype=self.dtype,
                                     partition=self.partition,
                                     local_shapes=self.local_shapes,
+                                    mask=self.mask,
                                     engine=self.engine,
                                     axis=self.axis)
         SumArray[:] = self.local_array + dist_array.local_array
@@ -392,6 +444,7 @@ def iadd(self, dist_array):
         """Distributed In-place Addition of arrays
         """
         self._check_partition_shape(dist_array)
+        self._check_mask(dist_array)
         self[:] = self.local_array + dist_array.local_array
         return self
 
@@ -400,12 +453,14 @@ def multiply(self, dist_array):
         """
         if isinstance(dist_array, DistributedArray):
             self._check_partition_shape(dist_array)
+            self._check_mask(dist_array)
 
         ProductArray = DistributedArray(global_shape=self.global_shape,
                                         base_comm=self.base_comm,
                                         dtype=self.dtype,
                                         partition=self.partition,
                                         local_shapes=self.local_shapes,
+                                        mask=self.mask,
                                         engine=self.engine,
                                         axis=self.axis)
         if isinstance(dist_array, DistributedArray):
@@ -420,13 +475,15 @@ def dot(self, dist_array):
         """Distributed Dot Product
         """
         self._check_partition_shape(dist_array)
+        self._check_mask(dist_array)
+
         # Convert to Partition.SCATTER if Partition.BROADCAST
         x = DistributedArray.to_dist(x=self.local_array) \
             if self.partition is Partition.BROADCAST else self
         y = DistributedArray.to_dist(x=dist_array.local_array) \
             if self.partition is Partition.BROADCAST else dist_array
         # Flatten the local arrays and calculate dot product
-        return self._allreduce(np.dot(x.local_array.flatten(), y.local_array.flatten()))
+        return self._allreduce_subcomm(np.dot(x.local_array.flatten(), y.local_array.flatten()))
 
     def _compute_vector_norm(self, local_array: NDArray, axis: int,
                              ord: Optional[int] = None):
@@ -453,20 +510,20 @@ def
_compute_vector_norm(self, local_array: NDArray, raise ValueError(f"norm-{ord} not possible for vectors") elif ord == 0: # Count non-zero then sum reduction - recv_buf = self._allreduce(np.count_nonzero(local_array, axis=axis).astype(np.float64)) + recv_buf = self._allreduce_subcomm(np.count_nonzero(local_array, axis=axis).astype(np.float64)) elif ord == np.inf: # Calculate max followed by max reduction - recv_buf = self._allreduce(np.max(np.abs(local_array), axis=axis).astype(np.float64), - recv_buf, op=MPI.MAX) + recv_buf = self._allreduce_subcomm(np.max(np.abs(local_array), axis=axis).astype(np.float64), + recv_buf, op=MPI.MAX) recv_buf = np.squeeze(recv_buf, axis=axis) elif ord == -np.inf: # Calculate min followed by min reduction - recv_buf = self._allreduce(np.min(np.abs(local_array), axis=axis).astype(np.float64), - recv_buf, op=MPI.MIN) + recv_buf = self._allreduce_subcomm(np.min(np.abs(local_array), axis=axis).astype(np.float64), + recv_buf, op=MPI.MIN) recv_buf = np.squeeze(recv_buf, axis=axis) else: - recv_buf = self._allreduce(np.sum(np.abs(np.float_power(local_array, ord)), axis=axis)) + recv_buf = self._allreduce_subcomm(np.sum(np.abs(np.float_power(local_array, ord)), axis=axis)) recv_buf = np.power(recv_buf, 1. / ord) return recv_buf @@ -500,6 +557,7 @@ def conj(self): partition=self.partition, axis=self.axis, local_shapes=self.local_shapes, + mask=self.mask, engine=self.engine, dtype=self.dtype) conj[:] = self.local_array.conj() @@ -513,6 +571,7 @@ def copy(self): partition=self.partition, axis=self.axis, local_shapes=self.local_shapes, + mask=self.mask, engine=self.engine, dtype=self.dtype) arr[:] = self.local_array @@ -535,6 +594,7 @@ def ravel(self, order: Optional[str] = "C"): local_shapes = [(np.prod(local_shape, axis=-1), ) for local_shape in self.local_shapes] arr = DistributedArray(global_shape=np.prod(self.global_shape), local_shapes=local_shapes, + mask=self.mask, partition=self.partition, engine=self.engine, dtype=self.dtype) diff --git a/pylops_mpi/basicoperators/BlockDiag.py b/pylops_mpi/basicoperators/BlockDiag.py index 7911692..74596e7 100644 --- a/pylops_mpi/basicoperators/BlockDiag.py +++ b/pylops_mpi/basicoperators/BlockDiag.py @@ -1,7 +1,8 @@ import numpy as np from scipy.sparse.linalg._interface import _get_dtype from mpi4py import MPI -from typing import Optional, Sequence +from typing import Optional, Sequence, Union, List +from numbers import Integral from pylops import LinearOperator from pylops.utils import DTypeLike @@ -28,6 +29,9 @@ class MPIBlockDiag(MPILinearOperator): One or more :class:`pylops.LinearOperator` to be stacked. base_comm : :obj:`mpi4py.MPI.Comm`, optional Base MPI Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``. + mask : :obj:`list`, optional + Mask defining subsets of ranks to consider when performing 'global' operations on + the distributed array such as dot product or norm. dtype : :obj:`str`, optional Type of elements in input array. 
@@ -95,8 +99,10 @@ class MPIBlockDiag(MPILinearOperator): def __init__(self, ops: Sequence[LinearOperator], base_comm: MPI.Comm = MPI.COMM_WORLD, + mask: Optional[List[Integral]] = None, dtype: Optional[DTypeLike] = None): self.ops = ops + self.mask = mask mops = np.zeros(len(self.ops), dtype=np.int64) nops = np.zeros(len(self.ops), dtype=np.int64) for iop, oper in enumerate(self.ops): @@ -116,7 +122,7 @@ def __init__(self, ops: Sequence[LinearOperator], def _matvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n, - engine=x.engine, dtype=self.dtype) + mask=self.mask, engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): y1.append(oper.matvec(x.local_array[self.mmops[iop]: @@ -128,7 +134,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: def _rmatvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) y = DistributedArray(global_shape=self.shape[1], local_shapes=self.local_shapes_m, - engine=x.engine, dtype=self.dtype) + mask=self.mask, engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: diff --git a/pylops_mpi/waveeqprocessing/MDC.py b/pylops_mpi/waveeqprocessing/MDC.py index 1d6524d..18980bf 100644 --- a/pylops_mpi/waveeqprocessing/MDC.py +++ b/pylops_mpi/waveeqprocessing/MDC.py @@ -20,7 +20,7 @@ def _MDC(G, nt, nv, nfmax, dt=1., dr=1., twosided=True, Used to be able to provide operators from different libraries to MDC. It operates in the same way as public method - (PoststackLinearModelling) but has additional input parameters allowing + (MPIMDC) but has additional input parameters allowing passing a different operator and additional arguments to be passed to such operator. @@ -81,8 +81,10 @@ def MPIMDC(G, nt, nv, nfreq, dt=1., dr=1., twosided=True, base_comm: MPI.Comm = MPI.COMM_WORLD): r"""Multi-dimensional convolution. - Apply multi-dimensional convolution between two datasets. Model and data - should be provided after flattening 2- or 3-dimensional arrays of size + Apply multi-dimensional convolution between two datasets in a distributed + fashion, with ``G`` distributed over ranks across the frequency axis. + Model and data are broadcasted and should be provided after flattening + 2- or 3-dimensional arrays of size :math:`[n_t \times n_r (\times n_{vs})]` and :math:`[n_t \times n_s (\times n_{vs})]` (or :math:`2*n_t-1` for ``twosided=True``), respectively. @@ -91,7 +93,7 @@ def MPIMDC(G, nt, nv, nfreq, dt=1., dr=1., twosided=True, ---------- G : :obj:`numpy.ndarray` Multi-dimensional convolution kernel in frequency domain of size - :math:`[n_{fmax} \times n_s \times n_r]` + :math:`[n_{f,rank} \times n_s \times n_r]` nt : :obj:`int` Number of samples along time axis for model and data (note that this must be equal to ``2*n_t-1`` when working with ``twosided=True``.
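
Usage sketch (not part of the patch; the 4-rank layout, variable names, and values are illustrative): the snippet below shows how the new ``mask`` argument might be used so that dot-product and norm reductions run over sub-communicators, assuming the same content is replicated across the two subsets of ranks and the script is launched with ``mpiexec -n 4``.

    import numpy as np
    from mpi4py import MPI
    from pylops_mpi.DistributedArray import DistributedArray

    comm = MPI.COMM_WORLD
    rank, size = comm.Get_rank(), comm.Get_size()

    # Two subsets of ranks: ranks 0-1 get color 0, ranks 2-3 get color 1.
    # Reductions in dot/norm are then carried out over the sub-communicator
    # of each subset, so data replicated across subsets is not double-counted.
    mask = [r // 2 for r in range(size)]

    x = DistributedArray(global_shape=10 * size, mask=mask)
    x[:] = float(rank % 2 + 1) * np.ones(10)  # same local content in both subsets
    y = x.conj()                              # mask is propagated to the new array

    xy = x.dot(y)       # allreduce over the sub-communicator only
    n2 = x.norm(ord=2)  # likewise for vector norms
    if rank == 0:
        print(xy, n2)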
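Similarly, a minimal sketch of how the new ``mask`` keyword of ``MPIBlockDiag`` could be combined with masked model vectors (again hypothetical: a 4-rank run with one small ``pylops.MatrixMult`` per rank is assumed), so that the output array created in ``_matvec``/``_rmatvec`` inherits the mask.

    import numpy as np
    from mpi4py import MPI
    from pylops import MatrixMult
    from pylops_mpi import MPIBlockDiag
    from pylops_mpi.DistributedArray import DistributedArray

    comm = MPI.COMM_WORLD
    rank, size = comm.Get_rank(), comm.Get_size()
    mask = [r // 2 for r in range(size)]  # two subsets of two ranks each

    # One 5x5 operator owned by each rank; MPIBlockDiag passes the mask on to
    # the DistributedArray it builds for the result of matvec/rmatvec
    Op = MatrixMult(np.eye(5) + np.ones((5, 5)))
    BDiag = MPIBlockDiag(ops=[Op, ], mask=mask)

    x = DistributedArray(global_shape=5 * size, mask=mask)
    x[:] = np.arange(5) + 1.0
    y = BDiag.matvec(x)  # y carries the same mask as x
    print(rank, y.mask, y.local_array)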