feat: added StackedDistributedArray and StackedVStack
mrava87 committed Feb 13, 2024
1 parent b811a30 commit e38e6f8
Showing 8 changed files with 289 additions and 29 deletions.
13 changes: 13 additions & 0 deletions examples/plot_distributed_array.py
@@ -71,12 +71,25 @@
pylops_mpi.plot_local_arrays(arr1, "Distributed Array - 1", vmin=0, vmax=1)
pylops_mpi.plot_local_arrays(arr2, "Distributed Array - 2", vmin=0, vmax=1)

###############################################################################
# **Scaling** - Each process operates on its local portion of
# the array and scales the corresponding elements by a given scalar.
scale_arr = .5 * arr1
pylops_mpi.plot_local_arrays(scale_arr, "Scaling", vmin=0, vmax=1)

###############################################################################
# **Element-wise Addition** - Each process operates on its local portion of
# the array and adds the corresponding elements together.
sum_arr = arr1 + arr2
pylops_mpi.plot_local_arrays(sum_arr, "Addition", vmin=0, vmax=1)

###############################################################################
# **Element-wise In-place Addition** - Similar to the previous one but the
# addition is performed directly on one of the addends without creating a new
# distributed array.
sum_arr += arr2
pylops_mpi.plot_local_arrays(sum_arr, "In-place Addition", vmin=0, vmax=1)

###############################################################################
# **Element-wise Subtraction** - Each process operates on its local portion
# of the array and subtracts the corresponding elements.
181 changes: 179 additions & 2 deletions pylops_mpi/DistributedArray.py
@@ -341,11 +341,20 @@ def __neg__(self):
def __add__(self, x):
return self.add(x)

def __iadd__(self, x):
return self.iadd(x)

def __sub__(self, x):
return self.__add__(-x)

def __isub__(self, x):
return self.__iadd__(-x)

def __mul__(self, x):
return self.multiply(x)

def __rmul__(self, x):
return self.multiply(x)

def add(self, dist_array):
"""Distributed Addition of arrays
@@ -360,17 +369,32 @@ def add(self, dist_array):
SumArray[:] = self.local_array + dist_array.local_array
return SumArray

def iadd(self, dist_array):
"""Distributed In-place Addition of arrays
"""
self._check_partition_shape(dist_array)
self[:] = self.local_array + dist_array.local_array
return self


def multiply(self, dist_array):
"""Distributed Element-wise multiplication
"""
self._check_partition_shape(dist_array)
if isinstance(dist_array, DistributedArray):
self._check_partition_shape(dist_array)

ProductArray = DistributedArray(global_shape=self.global_shape,
base_comm=self.base_comm,
dtype=self.dtype,
partition=self.partition,
local_shapes=self.local_shapes,
axis=self.axis)
ProductArray[:] = self.local_array * dist_array.local_array
if isinstance(dist_array, DistributedArray):
# multiply two DistributedArray
ProductArray[:] = self.local_array * dist_array.local_array
else:
# multiply with scalar
ProductArray[:] = self.local_array * dist_array
return ProductArray

def dot(self, dist_array):
@@ -557,3 +581,156 @@ def __repr__(self):
f"local shape={self.local_shape}" \
f", dtype={self.dtype}, " \
f"processes={[i for i in range(self.size)]})> "


class StackedDistributedArray:
r"""Stacked DistributedArrays
Stack DistributedArray objects and power them with basic mathematical operations.
This class allows one to work with a series of distributed arrays to avoid having to create
a single distributed array with some special internal sorting.
Parameters
----------
distarrays : :obj:`list`
List of :class:`pylops_mpi.DistributedArray` objects.
"""

def __init__(self, distarrays: List):
self.distarrays = distarrays
self.narrays = len(distarrays)

def __getitem__(self, index):
return self.distarrays[index]

def __setitem__(self, index, value):
self.distarrays[index][:] = value

def asarray(self):
"""Global view of the array
Gather all the distributed arrays
Returns
-------
final_array : :obj:`numpy.ndarray`
Global Array gathered at all ranks
"""
return np.hstack([distarr.asarray().ravel() for distarr in self.distarrays])

def _check_stacked_size(self, stacked_array):
"""Check that arrays have consistent size
"""
if self.narrays != stacked_array.narrays:
raise ValueError("Stacked arrays must be composed the same number of of distributed arrays")
for iarr in range(self.narrays):
if self.distarrays[iarr].global_shape != stacked_array[iarr].global_shape:
raise ValueError(f"Stacked arrays {iarr} have different global shape:"
f"{self.distarrays[iarr].global_shape} / "
f"{stacked_array[iarr].global_shape}")

def __neg__(self):
arr = self.copy()
for iarr in range(self.narrays):
arr[iarr][:] = -arr[iarr][:]
return arr

def __add__(self, x):
return self.add(x)

def __iadd__(self, x):
return self.iadd(x)

def __sub__(self, x):
return self.__add__(-x)

def __isub__(self, x):
return self.__iadd__(-x)

def __mul__(self, x):
return self.multiply(x)

def __rmul__(self, x):
return self.multiply(x)

def add(self, stacked_array):
"""Stacked Distributed Addition of arrays
"""
self._check_stacked_size(stacked_array)
SumArray = self.copy()
for iarr in range(self.narrays):
SumArray[iarr][:] = (self[iarr] + stacked_array[iarr])[:]
return SumArray

def iadd(self, stacked_array):
"""Stacked Distributed In-Place Addition of arrays
"""
self._check_stacked_size(stacked_array)
for iarr in range(self.narrays):
self[iarr][:] = (self[iarr] + stacked_array[iarr])[:]
return self

def multiply(self, stacked_array):
"""Stacked Distributed Element-wise Multiplication
"""
if isinstance(stacked_array, StackedDistributedArray):
self._check_stacked_size(stacked_array)
ProductArray = self.copy()

if isinstance(stacked_array, StackedDistributedArray):
# multiply two StackedDistributedArrays
for iarr in range(self.narrays):
ProductArray[iarr][:] = (self[iarr] * stacked_array[iarr])[:]
else:
# multiply with scalar
for iarr in range(self.narrays):
ProductArray[iarr][:] = (self[iarr] * stacked_array)[:]
return ProductArray

def dot(self, stacked_array):
"""Dot Product of Stacked Distributed Arrays
"""
self._check_stacked_size(stacked_array)
dotprod = 0.
for iarr in range(self.narrays):
dotprod += self[iarr].dot(stacked_array[iarr])
return dotprod

def norm(self, ord: Optional[int] = None):
"""numpy.linalg.norm method on stacked Distributed arrays
Parameters
----------
ord : :obj:`int`, optional
Order of the norm.
"""
norms = np.array([distarray.norm(ord) for distarray in self.distarrays])
ord = 2 if ord is None else ord
if ord in ['fro', 'nuc']:
raise ValueError(f"norm-{ord} not possible for vectors")
elif ord == 0:
# Count non-zero then sum reduction
norm = np.sum(norms)
elif ord == np.inf:
# Calculate max followed by max reduction
norm = np.max(norms)
elif ord == -np.inf:
# Calculate min followed by min reduction
norm = np.min(norms)
else:
norm = np.power(np.sum(np.power(norms, ord)), 1. / ord)
return norm

def conj(self):
"""Distributed conj() method
"""
ConjArray = StackedDistributedArray([distarray.conj() for distarray in self.distarrays])
return ConjArray

def copy(self):
"""Creates a copy of the DistributedArray
"""
arr = StackedDistributedArray([distarray.copy() for distarray in self.distarrays])
return arr

def __repr__(self):
repr_dist = "\n".join([distarray.__repr__() for distarray in self.distarrays])
return f"<StackedDistributedArray with {self.narrays} distributed arrays: \n" + repr_dist
2 changes: 1 addition & 1 deletion pylops_mpi/__init__.py
@@ -1,4 +1,4 @@
from .DistributedArray import DistributedArray, Partition
from .DistributedArray import DistributedArray, Partition, StackedDistributedArray
from .LinearOperator import *
from .basicoperators import *
from . import (
52 changes: 51 additions & 1 deletion pylops_mpi/basicoperators/VStack.py
@@ -6,7 +6,7 @@
from pylops import LinearOperator
from pylops.utils import DTypeLike

from pylops_mpi import MPILinearOperator, DistributedArray, Partition
from pylops_mpi import MPILinearOperator, DistributedArray, Partition, StackedDistributedArray
from pylops_mpi.utils.decorators import reshaped


@@ -128,3 +128,53 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
y1 = np.sum(y1, axis=0)
y[:] = self.base_comm.allreduce(y1, op=MPI.SUM)
return y


class StackedVStack:
r"""Stacked VStack Operator

Create a vertical stack of :class:`pylops_mpi.MPILinearOperator` operators.

Parameters
----------
ops : :obj:`list`
One or more :class:`pylops_mpi.MPILinearOperator` to be vertically stacked.

Attributes
----------
shape : :obj:`tuple`
Operator shape

Raises
------
ValueError
If ``ops`` have different number of columns

Notes
-----
A StackedVStack is composed of N :class:`pylops_mpi.MPILinearOperator` operators stacked
vertically. These MPI operators are applied sequentially; however, distributed
computations are performed within each operator.

"""

def __init__(self, ops: Sequence[MPILinearOperator]):
self.ops = ops
if len(set(op.shape[1] for op in ops)) > 1:
raise ValueError("Operators have different number of columns")
self.shape = (sum(op.shape[0] for op in ops),
ops[0].shape[1])
self.dtype = _get_dtype(self.ops)

def matvec(self, x: DistributedArray) -> StackedDistributedArray:
y1 = []
for oper in self.ops:
y1.append(oper.matvec(x))
y = StackedDistributedArray(y1)
return y

def rmatvec(self, x: StackedDistributedArray) -> DistributedArray:
y = self.ops[0].rmatvec(x[0])
for xx, oper in zip(x[1:], self.ops[1:]):
y = y + oper.rmatvec(xx)
return y
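
A hedged usage sketch of StackedVStack (not part of this commit): the MatrixMult blocks and their 5 x 4 shape are assumptions for illustration, with each rank holding one diagonal block of each MPIBlockDiag.

import numpy as np
from mpi4py import MPI
from pylops.basicoperators import MatrixMult
import pylops_mpi
from pylops_mpi.basicoperators import MPIBlockDiag, StackedVStack

# one illustrative 5 x 4 block per rank in each block-diagonal MPI operator
Op1 = MPIBlockDiag([MatrixMult(np.ones((5, 4)))])
Op2 = MPIBlockDiag([MatrixMult(2. * np.ones((5, 4)))])
Vop = StackedVStack([Op1, Op2])

x = pylops_mpi.DistributedArray(global_shape=4 * MPI.COMM_WORLD.Get_size())
x[:] = 1.0
y = Vop.matvec(x)      # StackedDistributedArray with two components
xadj = Vop.rmatvec(y)  # reduced back to a single DistributedArray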
10 changes: 7 additions & 3 deletions pylops_mpi/basicoperators/__init__.py
@@ -7,10 +7,13 @@
using MPI.
A list of operators present in pylops_mpi.basicoperators :
MPIBlockDiag Block Diagonal Operator
MPIVStack Vertical Stacking
MPIHStack Horizontal Stacking
MPIBlockDiag Block Diagonal arrangement of PyLops operators
MPIVStack Vertical Stacking of PyLops operators
StackedVStack Vertical Stacking of PyLops-MPI operators
MPIHStack Horizontal Stacking of PyLops operators
MPIFirstDerivative First Derivative
MPISecondDerivative Second Derivative
MPILaplacian Laplacian
"""

@@ -24,6 +27,7 @@
__all__ = [
"MPIBlockDiag",
"MPIVStack",
"StackedVStack",
"MPIHStack",
"MPIFirstDerivative",
"MPISecondDerivative",
22 changes: 8 additions & 14 deletions pylops_mpi/optimization/cls_basic.py
@@ -124,11 +124,11 @@ def step(self, x: DistributedArray, show: bool = False) -> DistributedArray:
Opc = self.Op.matvec(self.c)
cOpc = np.abs(self.c.dot(Opc.conj()))
a = self.kold / cOpc
x[:] += a * self.c.local_array
self.r[:] -= a * Opc.local_array
x += a * self.c
self.r -= a * Opc
k = np.abs(self.r.dot(self.r.conj()))
b = k / self.kold
self.c[:] = self.r.local_array + b * self.c.local_array
self.c = self.r + b * self.c
self.kold = k
self.iiter += 1
self.cost.append(float(np.sqrt(self.kold)))
@@ -344,10 +344,7 @@ def setup(self,
else:
x = x0.copy()
self.s = self.y - self.Op.matvec(x)
damped_x = DistributedArray(global_shape=x.global_shape, dtype=x.dtype,
local_shapes=x.local_shapes,
partition=x.partition)
damped_x[:] = damp * x.local_array
damped_x = damp * x
r = self.Op.rmatvec(self.s) - damped_x
self.rank = x.rank
self.c = r.copy()
@@ -384,16 +381,13 @@ def step(self, x: DistributedArray, show: bool = False) -> DistributedArray:
"""

a = self.kold / (self.q.dot(self.q.conj()) + self.damp * self.c.dot(self.c.conj()))
x[:] = x.local_array + a * self.c.local_array
self.s[:] = self.s.local_array - a * self.q.local_array
damped_x = DistributedArray(global_shape=x.global_shape, dtype=x.dtype,
local_shapes=x.local_shapes,
partition=x.partition)
damped_x[:] = self.damp * x.local_array
x += a * self.c
self.s -= a * self.q
damped_x = self.damp * x
r = self.Op.rmatvec(self.s) - damped_x
k = np.abs(r.dot(r.conj()))
b = k / self.kold
self.c[:] = r.local_array + b * self.c.local_array
self.c = r + b * self.c
self.q = self.Op.matvec(self.c)
self.kold = k
self.iiter += 1
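The changes above replace explicit local_array bookkeeping with the new DistributedArray operator overloads. A small sketch (not library code) of the equivalence being relied upon; the shape and scalar are arbitrary:

import numpy as np
import pylops_mpi

a = 0.5
x = pylops_mpi.DistributedArray(global_shape=8)
c = pylops_mpi.DistributedArray(global_shape=8)
x[:] = 1.0
c[:] = 2.0

xold = x.copy()
xold[:] = xold.local_array + a * c.local_array  # old style: explicit local update
x += a * c                                      # new style: __iadd__ and __rmul__
assert np.allclose(x.local_array, xold.local_array)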
File renamed without changes.
