Refactor fixes
holl- committed Dec 26, 2024
1 parent 8f81fa5 commit 5244526
Showing 6 changed files with 89 additions and 59 deletions.
2 changes: 1 addition & 1 deletion phiml/math/_ops.py
@@ -353,7 +353,7 @@ def _initialize(uniform_initializer, shapes: Tuple[Shape]) -> Tensor:
shape = concat_shapes(*shapes)
assert shape.well_defined, f"When creating a Tensor, shape needs to have definitive sizes but got {shape}"
if shape.is_non_uniform:
stack_dim = shape.shape.without('dims')[0:1]
stack_dim = shape.non_uniform_shape[0]
shapes = shape.unstack(stack_dim.name)
tensors = [_initialize(uniform_initializer, s) for s in shapes]
return stack_tensors(tensors, stack_dim)
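The change above swaps the ad-hoc `shape.shape.without('dims')` lookup for the `non_uniform_shape` property, which lists the dims along which a non-uniform shape's sizes vary. A minimal sketch of that property in use (my illustration, not part of this commit):

```python
from phiml import math
from phiml.math import batch, spatial

# Stacking tensors of different sizes yields a non-uniform shape:
# the size of 'x' depends on the index along 'b'.
t = math.stack([math.zeros(spatial(x=3)), math.zeros(spatial(x=5))], batch('b'))
s = t.shape
print(s.is_non_uniform)     # True
print(s.non_uniform_shape)  # expected: just the 'b' dim, along which sizes vary
```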
117 changes: 71 additions & 46 deletions phiml/math/_shape.py
@@ -1092,6 +1092,11 @@ def with_sizes(self, sizes: Union[Sequence[int], Sequence[Tuple[str, ...]], 'Sha
def without_sizes(self):
return Dim(self.name, None, self.dim_type, None)

def flipped(self, dims: Union[List[str], Tuple[str]]):
if self.slice_names is None or self.name not in dims:
return self
return Dim(self.name, self.size, self.dim_type, self.slice_names[::-1])

def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape'):
assert self.name == parse_dim_names(dims, 1)[0]
return self._replace(new)
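The new `flipped()` method reverses the stored item (slice) names of the listed dims and leaves everything else untouched. A hedged sketch of the expected behavior:

```python
from phiml.math import channel

c = channel(vector='x,y,z')
print(c.flipped(['vector']))  # expected: item names reversed to z, y, x
print(c.flipped(['other']))   # 'vector' not listed: returned unchanged
```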
@@ -1329,6 +1334,7 @@ def __getitem__(self, selection):
elif isinstance(selection, Shape):
selection = selection.names
if isinstance(selection, (tuple, list)):
selection = [self.names[sel] if isinstance(sel, int) else sel for sel in selection]
if DEBUG_CHECKS:
assert all(isinstance(s, str) for s in selection)
return concat_shapes_(*[self.dims[n] for n in selection])
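With the added translation line, list or tuple selections on a `Shape` may now mix integer positions with dim names; ints are resolved to names before lookup. Roughly:

```python
from phiml.math import batch, spatial, channel

s = batch(b=2) & spatial(x=10, y=8) & channel(vector='x,y')
print(s[['b', 'vector']])  # select by name, as before
print(s[[0, 'y']])         # with this change, 0 is resolved to the first dim name ('b')
```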
@@ -1475,6 +1481,10 @@ def with_sizes(self, sizes: Union[Sequence[int], Sequence[Tuple[str, ...]], 'Sha
def without_sizes(self):
return PureShape(self.dim_type, {n: dim.without_sizes() for n, dim in self.dims.items()})

def flipped(self, dims: Union[List[str], Tuple[str]]):
dims = {n: dim.flipped(dims) for n, dim in self.dims.items()}
return PureShape(self.dim_type, dims)

def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape'):
dims = parse_dim_order(dims)
dim_list = list(self.dims.values())
@@ -1636,32 +1646,32 @@ def channel_rank(self) -> int:
@property
def primal(self):
dims = {**self.instance.dims, **self.spatial.dims, **self.channel.dims}
return MixedShape(EMPTY_SHAPE, EMPTY_SHAPE, self.instance, self.spatial, self.channel, dims)
return MixedShape(EMPTY_SHAPE, EMPTY_SHAPE, self.instance, self.spatial, self.channel, dims) if dims else EMPTY_SHAPE
@property
def non_primal(self):
dims = {**self.batch.dims, **self.dual.dims}
return MixedShape(self.batch, self.dual, EMPTY_SHAPE, EMPTY_SHAPE, EMPTY_SHAPE, dims)
return MixedShape(self.batch, self.dual, EMPTY_SHAPE, EMPTY_SHAPE, EMPTY_SHAPE, dims) if dims else EMPTY_SHAPE

@property
def non_batch(self):
dims = {n: dim for n, dim in self.dims.items() if dim.dim_type != BATCH_DIM}
return MixedShape(EMPTY_SHAPE, self.dual, self.instance, self.spatial, self.channel, dims)
return MixedShape(EMPTY_SHAPE, self.dual, self.instance, self.spatial, self.channel, dims) if dims else EMPTY_SHAPE
@property
def non_dual(self):
dims = {n: dim for n, dim in self.dims.items() if dim.dim_type != DUAL_DIM}
return MixedShape(self.batch, EMPTY_SHAPE, self.instance, self.spatial, self.channel, dims)
return MixedShape(self.batch, EMPTY_SHAPE, self.instance, self.spatial, self.channel, dims) if dims else EMPTY_SHAPE
@property
def non_instance(self):
dims = {n: dim for n, dim in self.dims.items() if dim.dim_type != INSTANCE_DIM}
return MixedShape(self.batch, self.dual, EMPTY_SHAPE, self.spatial, self.channel, dims)
return MixedShape(self.batch, self.dual, EMPTY_SHAPE, self.spatial, self.channel, dims) if dims else EMPTY_SHAPE
@property
def non_spatial(self):
dims = {n: dim for n, dim in self.dims.items() if dim.dim_type != SPATIAL_DIM}
return MixedShape(self.batch, self.dual, self.instance, EMPTY_SHAPE, self.channel, dims)
return MixedShape(self.batch, self.dual, self.instance, EMPTY_SHAPE, self.channel, dims) if dims else EMPTY_SHAPE
@property
def non_channel(self):
dims = {n: dim for n, dim in self.dims.items() if dim.dim_type != CHANNEL_DIM}
return MixedShape(self.batch, self.dual, self.instance, self.spatial, EMPTY_SHAPE, dims)
return MixedShape(self.batch, self.dual, self.instance, self.spatial, EMPTY_SHAPE, dims) if dims else EMPTY_SHAPE

def __repr__(self):
return '(' + ', '.join([repr(dim)[1:-1] for dim in self.dims.values()]) + ')'
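Each of these filtered properties now collapses to `EMPTY_SHAPE` when no dims remain, instead of returning an empty `MixedShape` wrapper. A sketch of the expected effect (the identity assertion is my assumption):

```python
from phiml.math import batch, dual, EMPTY_SHAPE

s = batch(b=2) & dual(n=3)
print(s.primal)                 # no instance/spatial/channel dims -> ()
assert s.primal is EMPTY_SHAPE  # assumed: empty results now collapse to the EMPTY_SHAPE singleton
```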
@@ -1714,6 +1724,7 @@ def __getitem__(self, selection):
elif isinstance(selection, Shape):
selection = selection.names
if isinstance(selection, (tuple, list)):
selection = [self.names[sel] if isinstance(sel, int) else sel for sel in selection]
if DEBUG_CHECKS:
assert all(isinstance(s, str) for s in selection)
return concat_shapes_(*[self.dims[n] for n in selection])
@@ -1853,6 +1864,9 @@ def with_sizes(self, sizes: Union[Sequence[int], Sequence[Tuple[str, ...]], 'Sha
def without_sizes(self):
return concat_shapes_(*[dim.without_sizes() for dim in self.dims.values()])

def flipped(self, dims: Union[List[str], Tuple[str]]):
return concat_shapes_(*[dim.flipped(dims) for n, dim in self.dims.items()])

def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape'):
dims = parse_dim_order(dims)
dim_list = list(self.dims.values())
@@ -2565,8 +2579,13 @@ def concat_shapes_(*shapes: Shape) -> Shape:

def shape_stack(stack_dim: Shape, *shapes: Shape, stack_dim_first=True):
""" Returns the shape of a tensor created by stacking tensors with `shapes`. """
assert stack_dim.rank == 1, f"stack_dim must be a single dim but got {stack_dim}"
stack_dim = Dim(stack_dim.name, len(shapes), stack_dim.dim_type, stack_dim.item_names[0])
if stack_dim.rank > 1:
assert stack_dim.volume == len(shapes), f"stack_dim {stack_dim} does not match number of shapes: {len(shapes)}"
elif len(stack_dim) == 1:
stack_dim = Dim(stack_dim.name, len(shapes), stack_dim.dim_type, stack_dim.item_names[0])
else:
assert len(shapes) == 1
return shapes[0]
if not shapes:
return stack_dim
if len(shapes) == 1:
@@ -2605,6 +2624,48 @@ def shape_stack(stack_dim: Shape, *shapes: Shape, stack_dim_first=True):
return concat_shapes_(stack_dim, *sdims) if stack_dim_first else concat_shapes_(*sdims, stack_dim)
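`shape_stack` now accepts stack dims of any rank: for rank > 1 the volume must match the number of shapes, for a single dim the size is filled in as before, and a rank-0 stack dim simply passes the lone shape through. An illustrative direct call of the internal helper, assuming this mirrors how `math.stack` builds result shapes:

```python
from phiml.math import batch, spatial, EMPTY_SHAPE
from phiml.math._shape import shape_stack  # internal helper, not public API

# Rank-1 stack dim: its size is set to the number of shapes; 'x' becomes non-uniform.
print(shape_stack(batch('stack'), spatial(x=3), spatial(x=5)))
# Rank-0 stack dim: the single shape is returned unchanged.
print(shape_stack(EMPTY_SHAPE, spatial(x=3)))
```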


def unstack(self: Shape, dim: str) -> Sequence[Shape]:
"""
Slices this `Shape` along a dimension.
The dimension listing the sizes of the shape is referred to as `'dims'`.
Non-uniform tensor shapes may be unstacked along other dimensions as well, see
https://tum-pbs.github.io/PhiML/Non_Uniform.html
Args:
dim: dimension to unstack
Returns:
slices of this shape
"""
assert dim != 'dims'
if dim not in self and self.is_uniform:
return self,
elif self.is_uniform:
return (self-dim,) * self.get_size(dim)
# --- non-uniform case ---
from ._tensors import Tensor
if dim in self:
inner = self.without(dim)
dim_size = self.get_size(dim)
else:
inner = self
dim_size = self.shape.get_size(dim)
sizes = []
for size in inner.sizes:
if isinstance(size, Tensor) and dim in size.shape:
sizes.append(size._unstack(dim))
dim_size = size.shape.get_size(dim)
else:
sizes.append(size)
assert isinstance(dim_size, int)
result = []
for i in range(dim_size):
sizes_i = [int(size[i]) if isinstance(size, tuple) else size for size in sizes]
result.append(inner.with_sizes(sizes_i))
return result
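This module-level `unstack` replaces the former method (removed further down) and is attached to `Dim`, `PureShape` and `MixedShape` at the end of the file. A short usage sketch, assuming the behavior described in the docstring:

```python
from phiml import math
from phiml.math import batch, spatial

t = math.stack([math.zeros(spatial(x=3)), math.zeros(spatial(x=5))], batch('b'))
s = t.shape                       # non-uniform: the size of x varies along b
print(s.unstack('b'))             # expected: a shape (x=3) and a shape (x=5)
print(spatial(x=4).unstack('b'))  # dim absent from a uniform shape: the shape itself
```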


def prepare_gather(self: Shape, dim: str, selection: Union[slice, int, 'Shape', str, tuple, list]) -> Union[slice, List[int]]:
"""
Parse a slice object for a specific dimension.
@@ -2736,43 +2797,6 @@ def after_pad(self, widths: dict) -> 'Shape':
return self.with_sizes(sizes)


def unstack(self, dim='dims') -> Tuple['Shape']:
"""
Slices this `Shape` along a dimension.
The dimension listing the sizes of the shape is referred to as `'dims'`.
Non-uniform tensor shapes may be unstacked along other dimensions as well, see
https://tum-pbs.github.io/PhiML/Non_Uniform.html
Args:
dim: dimension to unstack
Returns:
slices of this shape
"""
if dim == 'dims':
return tuple(Shape((self.sizes[i],), (self.names[i],), (self.types[i],), (self.item_names[i],)) for i in range(self.rank))
if dim not in self and self.is_uniform:
return tuple([self])
from ._tensors import Tensor
if dim in self:
inner = self.without(dim)
dim_size = self.get_size(dim)
else:
inner = self
dim_size = self.shape.get_size(dim)
sizes = []
for size in inner.sizes:
if isinstance(size, Tensor) and dim in size.shape:
sizes.append(size._unstack(dim))
dim_size = size.shape.get_size(dim)
else:
sizes.append(size)
assert isinstance(dim_size, int)
shapes = tuple(Shape(tuple([int(size[i]) if isinstance(size, tuple) else size for size in sizes]), inner.names, inner.types, inner.item_names) for i in range(dim_size))
return shapes


def transpose(self, dims: DimFilter):
if callable(dims) and dims in TYPE_BY_FUNCTION:
dims = TYPE_BY_FUNCTION[dims]
@@ -2824,6 +2848,7 @@ def first_index(shape: Shape):


for cls in [Dim, PureShape, MixedShape]:
cls.unstack = unstack
cls.after_gather = after_gather
cls.after_pad = after_pad
cls.first_index = first_index
19 changes: 11 additions & 8 deletions phiml/math/_tensors.py
@@ -19,7 +19,7 @@
SUPERSCRIPT, IncompatibleShapes, INSTANCE_DIM, batch, spatial, dual, instance, shape, shape as shape_, DimFilter, non_batch, DEBUG_CHECKS, parse_shape_spec,
prepare_renaming_gather, after_gather, concat_shapes_, Dim, PureShape)
from ..backend import NoBackendFound, choose_backend, BACKENDS, get_precision, default_backend, convert as convert_, \
Backend, ComputeDevice, OBJECTS, NUMPY
Backend, ComputeDevice, OBJECTS, NUMPY, ML_LOGGER
from ..backend._dtype import DType, combine_types
from .magic import BoundDim, PhiTreeNode, slicing_dict, Shaped, _BoundDims
from .magic import Shapable
@@ -1236,8 +1236,8 @@ def __init__(self, native_tensor, names: Sequence[str], expanded_shape: Shape, b
self._names = names
self._backend = backend
if DEBUG_CHECKS:
assert isinstance(names, (tuple, list))
assert all(isinstance(n, str) for n in names)
assert isinstance(names, (tuple, list)), f"names must be a tuple or list[str] but got {type(names)}"
assert all(isinstance(n, str) for n in names), f"names must be a tuple or list[str] but got {names}"
assert isinstance(backend, Backend)
for dim in expanded_shape:
if dim.size is not None and isinstance(dim.size, Tensor):
@@ -1252,7 +1252,10 @@
def native(self, order: Union[str, tuple, list, Shape] = None, force_expand=True):
if order is None:
assert len(self._shape) <= 1, f"When calling Tensor.native() or Tensor.numpy(), the dimension order must be specified for Tensors with more than one dimension, e.g. '{','.join(self._shape.names)}'. The listed default dimension order can vary depending on the chosen backend. Consider using math.reshaped_native(Tensor) instead."
return self._native if len(self._names) <= 1 or not force_expand else self.backend.tile(self._native, (self._shape.size,))
if len(self._names) == len(self._shape):
return self.backend.auto_cast(self._native)[0]
assert len(self._names) == 0 # shape.rank is 1
return self.backend.tile(self.backend.expand_dims(self.backend.auto_cast(self._native)[0]), (self._shape.size,))
if isinstance(order, str):
return self._transposed_native(parse_dim_order(order), force_expand)
if isinstance(order, (tuple, list)) and all(isinstance(o, str) for o in order):
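The rewritten `order is None` branch distinguishes fully stored tensors (just cast) from collapsed ones, whose single stored value is expanded and tiled to the declared size. The user-facing effect, sketched (the casting details are internal and assumed here):

```python
from phiml import math
from phiml.math import batch

z = math.zeros(batch(b=3))   # stored as a single value, declared as (b=3)
print(z.shape)               # batch dim of size 3
print(z.native('b').shape)   # (3,) – the stored value is tiled to the full size
```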
@@ -1312,7 +1315,7 @@ def _cached(self, dims: Shape = None) -> 'NativeTensor':
if len(self._names) == len(self._shape): # nothing to expand
return self
elif dims is None or len(self._shape) == len(set(self._names) | set(dims.names)): # expand all
return NativeTensor(self.native(order=self._shape), self._shape, self._shape, self._backend)
return NativeTensor(self.native(order=self._shape), self._shape.names, self._shape, self._backend)
else: # expand specific dims
new_native_shape = dims & self._shape[self._names]
tmp_tensor = NativeTensor(self._native, self._names, new_native_shape, self._backend)
@@ -1696,7 +1699,7 @@ def tensor(data,
return data
else:
if None in shape.sizes:
shape = shape.with_sizes(data.shape)
shape = shape.with_sizes(data.shape.sizes)
return data._with_shape_replaced(shape)
elif isinstance(data, str) or data is None:
return layout(data)
@@ -2133,7 +2136,7 @@ def expand_tensor(value: Tensor, dims: Shape):
if dims.is_uniform:
return NativeTensor(value._native, value._names, dims & value._shape, value._backend)
else:
stack_dim = dims.shape.without('dims')
stack_dim = dims.non_uniform_shape
if stack_dim.rank > 1:
raise NotImplementedError(f"Higher-order non-uniform expand() not yet supported. Tried expanding {value.shape} by {dims}")
unstacked_dims = [after_gather(dims, i) for i in stack_dim.meshgrid()]
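`expand_tensor` now also reads the stacking dims from `non_uniform_shape`. The updated test in `tests/commit/math/test__tensors.py` exercises this path; a reduced version of the same pattern:

```python
from phiml import math
from phiml.math import batch, vec

size = vec(batch('dataset_size'), 2, 4, 8)    # sizes 2, 4, 8 along 'dataset_size'
b = batch(example=size, seed=4) & size.shape  # 'example' size varies with 'dataset_size'
t = math.random_uniform(b)                    # non-uniform random tensor
print(t.shape)
```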
@@ -3095,7 +3098,7 @@ def backend_for(*values: Tensor):
return result
natives = sum([v._natives() if isinstance(v, Tensor) else (v,) for v in values], ())
result = _BACKEND_RULES[backends] = choose_backend(*natives)
print(f"New backend combination: {backends} -> {result}")
ML_LOGGER.debug(f"Caching new backend combination: {backends} -> {result}")
return result
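The backend-combination message now goes through `ML_LOGGER` at debug level rather than `print`. To surface it again:

```python
import logging
from phiml.backend import ML_LOGGER

logging.basicConfig()
ML_LOGGER.setLevel(logging.DEBUG)  # shows the "Caching new backend combination" messages
```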


2 changes: 1 addition & 1 deletion phiml/math/extrapolation.py
@@ -1314,7 +1314,7 @@ def sparse_pad_values(self, value: Tensor, connectivity: Tensor, dim: str, **kwa

def transform_coordinates(self, coordinates: Tensor, shape: Shape, **kwargs) -> Tensor:
result = []
for dim in shape.spatial.unstack():
for dim in shape.spatial:
dim_coords = coordinates[[dim.name]]
le = self._at_boundary(dim.name+'-')
ue = self._at_boundary(dim.name+'+')
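`transform_coordinates` now iterates `shape.spatial` directly; iterating a `Shape` yields one single-dim shape per entry, so the old `unstack()` call is unnecessary here. For instance:

```python
from phiml.math import spatial

s = spatial(x=4, y=3)
for dim in s:                  # iterating a Shape yields single-dim shapes
    print(dim.name, dim.size)  # x 4, then y 3
```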
2 changes: 1 addition & 1 deletion tests/commit/math/test__tensors.py
@@ -679,7 +679,7 @@ def test_single_index_gather(self):

def test_expand_non_uniform(self):
size = vec(batch('dataset_size'), 2, 4, 8, 16, 64, 256)
b = batch(example=size, seed=64) & size
b = batch(example=size, seed=64) & size.shape
t = math.random_uniform(b)
curves = vec(dataset_size=size, fraction=t)
print(curves.shape)
6 changes: 4 additions & 2 deletions tests/commit/math/test_extrapolation.py
@@ -149,10 +149,12 @@ def test_pad_collapsed(self):
a = math.zeros(spatial(b=2, x=10, y=10) & batch(batch=10))
p = math.pad(a, {'x': (1, 2)}, ZERO)
self.assertEqual(0, len(p._names))
self.assertEqual((10, 2, 13, 10), p.shape.sizes)
self.assertEqual(13, p.shape.get_size('x'))
self.assertEqual(10, p.shape.get_size('y'))
p = math.pad(a, {'x': (1, 2)}, PERIODIC)
self.assertEqual(0, len(p._names))
self.assertEqual((10, 2, 13, 10), p.shape.sizes)
self.assertEqual(13, p.shape.get_size('x'))
self.assertEqual(10, p.shape.get_size('y'))
# --- 1D ---
p = math.pad(math.ones(spatial(x=3)), {'x': (1, 1)}, 0)
math.assert_close([0, 1, 1, 1, 0], p)
