Shape refactor 5
* NativeTensor._names
holl- committed Dec 22, 2024
1 parent 26296ab commit 74e06e7
Showing 4 changed files with 96 additions and 90 deletions.
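The gist of the refactor, as far as it is visible in the hunks below: NativeTensor is now constructed with an explicit sequence of dimension names in addition to the expanded Shape, and Shape.indices() now takes a sequence of dim names instead of a Shape. A minimal, self-contained sketch of the name-based index lookup (the dim_indices helper and the example names are illustrative only, not part of the phiml API):

from typing import Sequence, Tuple

def dim_indices(order: Sequence[str], names: Sequence[str]) -> Tuple[int, ...]:
    # Mirrors the new name-based Shape.indices(): look up each requested
    # dim name in the tensor's dim order; each name may occur only once.
    return tuple(order.index(n) for n in names)

# Example: reducing over ('x', 'y') of a tensor laid out as (batch, x, y, vector)
assert dim_indices(('batch', 'x', 'y', 'vector'), ('x', 'y')) == (1, 2)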
phiml/math/_ops.py: 44 changes (22 additions, 22 deletions)
@@ -17,7 +17,7 @@
from ._tensors import (Tensor, wrap, tensor, broadcastable_native_tensors, NativeTensor, TensorStack,
custom_op2, compatible_tensor, variable_attributes, disassemble_tree, assemble_tree,
is_scalar, Layout, expand_tensor, TensorOrTree, cached, variable_shape,
-reshaped_native, reshaped_tensor, discard_constant_dims)
+reshaped_native, reshaped_tensor, discard_constant_dims, variable_dim_names)
from ._sparse import (CompressedSparseMatrix, dense, SparseCoordinateTensor, get_format, to_format, stored_indices,
tensor_like, sparse_dims, same_sparsity_pattern, is_sparse, sparse_dot, sparse_sum, sparse_gather, sparse_max,
sparse_min, dense_dims, sparse_mean, stored_values, sparse_matrix_dims, CompactSparseTensor)
@@ -742,7 +742,7 @@ def meshgrid(dims: Union[Callable, Shape] = spatial, stack_dim=channel('vector')
grid_shape = dim_type(**{dim: size for dim, size in zip(dimensions.keys(), dim_sizes)})
backend = choose_backend(*dim_values, prefer_default=True)
indices_list = backend.meshgrid(*dim_values)
-channels = [NativeTensor(t, grid_shape) for t in indices_list]
+channels = [NativeTensor(t, grid_shape.names, grid_shape) for t in indices_list]
if not stack_dim:
assert len(channels) == 1, f"meshgrid with multiple dimension requires a valid stack_dim but got {stack_dim}"
return channels[0]
@@ -882,13 +882,12 @@ def stack_tensors(values: Union[tuple, list], dim: Shape):
return TensorStack(values, dim)
# --- uniform stack ---
dim = dim.with_size(len(values))
-native_shapes = [variable_shape(v) for v in values]
-native_broadcast_shape = merge_shapes(*native_shapes)
+native_broadcast_shape = merge_shapes(*[variable_shape(v) for v in values])
natives = [reshaped_native(discard_constant_dims(v), [*native_broadcast_shape], force_expand=True) for v in values]
-native_shape = native_broadcast_shape & dim
-native_stacked = choose_backend(*natives).stack(natives, axis=native_shape.index(dim))
+names = (native_broadcast_shape & dim).names
+native_stacked = choose_backend(*natives).stack(natives, axis=names.index(dim.name))
expanded_shape = merge_shapes(*[v.shape for v in values]) & dim
-return NativeTensor(native_stacked, native_shape, expanded_shape)
+return NativeTensor(native_stacked, names, expanded_shape)


def concat_tensor(values: Union[tuple, list], dim: str) -> Tensor:
@@ -1259,7 +1258,7 @@ def inner_where(c: Tensor, vt: Tensor, vf: Tensor):
return c._with_values(where(c_values, vt_values, vf_values))
shape, (c, vt, vf) = broadcastable_native_tensors(c, vt, vf)
result = choose_backend(c, vt, vf).where(c, vt, vf)
-return NativeTensor(result, shape)
+return NativeTensor(result, shape.names, shape)

return broadcast_op(inner_where, [condition, value_true, value_false])

@@ -1453,8 +1452,8 @@ def _sum(value: Tensor, dims: Shape) -> Tensor:
if not dims:
return value
if isinstance(value, NativeTensor):
-result = value.default_backend.sum(value._native, value._native_shape.indices(dims)) * value.collapsed_dims.only(dims).volume
-return NativeTensor(result, value._native_shape.without(dims), value.shape.without(dims))
+result = value.default_backend.sum(value._native, tuple(value._names.index(n) for n in dims.names)) * value.collapsed_dims.only(dims).volume
+return NativeTensor(result, [n for n in value._names if n not in dims], value.shape.without(dims))
elif isinstance(value, TensorStack):
reduced_inners = [_sum(t, dims.without(value._stack_dim)) for t in value._tensors]
return functools.reduce(lambda x, y: x + y, reduced_inners) if value._stack_dim in dims else TensorStack(reduced_inners, value._stack_dim)
@@ -1504,7 +1503,7 @@ def prod(value, dim: DimFilter = non_batch) -> Tensor:

def _prod(value: Tensor, dims: Shape) -> Tensor:
if isinstance(value, NativeTensor):
-result = value.default_backend.prod(value._native, value._native_shape.indices(dims)) ** value.collapsed_dims.only(dims).volume
+result = value.default_backend.prod(value._native, value._native_shape.indices(dims.names)) ** value.collapsed_dims.only(dims).volume
return NativeTensor(result, value._native_shape.without(dims), value.shape.without(dims))
elif isinstance(value, TensorStack):
reduced_inners = [_prod(t, dims.without(value._stack_dim)) for t in value._tensors]
@@ -1569,7 +1568,7 @@ def _mean(value: Tensor, dims: Shape) -> Tensor:
if not dims:
return value
if isinstance(value, NativeTensor):
-result = value.default_backend.mean(value._native, value._native_shape.indices(dims))
+result = value.default_backend.mean(value._native, value._native_shape.indices(dims.names))
return NativeTensor(result, value._native_shape.without(dims), value.shape.without(dims))
elif isinstance(value, TensorStack):
if value._stack_dim in dims:
@@ -1611,7 +1610,7 @@ def std(value, dim: DimFilter = non_batch) -> Tensor:

def _std(value: Tensor, dims: Shape) -> Tensor:
if value.shape.is_uniform:
-result = value.default_backend.std(value.native(value.shape), value.shape.indices(dims))
+result = value.default_backend.std(value.native(value.shape), value.shape.indices(dims.names))
return NativeTensor(result, value.shape.without(dims))
else:
non_uniform_dims = value.shape.shape.without('dims')
@@ -1647,7 +1646,7 @@ def any_(boolean_value, dim: DimFilter = non_batch) -> Tensor:

def _any(value: Tensor, dims: Shape) -> Tensor:
if isinstance(value, NativeTensor):
-result = value.default_backend.any(value._native, value._native_shape.indices(dims))
+result = value.default_backend.any(value._native, value._native_shape.indices(dims.names))
return NativeTensor(result, value._native_shape.without(dims), value.shape.without(dims))
elif isinstance(value, TensorStack):
reduced_inners = [_any(t, dims.without(value._stack_dim)) for t in value._tensors]
@@ -1681,7 +1680,7 @@ def all_(boolean_value, dim: DimFilter = non_batch) -> Tensor:

def _all(value: Tensor, dims: Shape) -> Tensor:
if isinstance(value, NativeTensor):
-result = value.default_backend.all(value.native(value.shape), value.shape.indices(dims))
+result = value.default_backend.all(value.native(value.shape), value.shape.indices(dims.names))
return NativeTensor(result, value.shape.without(dims))
elif isinstance(value, TensorStack):
reduced_inners = [_all(t, dims.without(value._stack_dim)) for t in value._tensors]
@@ -1738,7 +1737,7 @@ def _max(value: Tensor, dims: Shape) -> Tensor:
if value.shape.volume == 0:
return zeros(value.shape.without(dims), dtype=value.dtype)
if isinstance(value, NativeTensor):
-result = value.default_backend.max(value.native(value.shape), value.shape.indices(dims))
+result = value.default_backend.max(value.native(value.shape), value.shape.indices(dims.names))
return NativeTensor(result, value.shape.without(dims))
elif isinstance(value, TensorStack):
reduced_inners = [_max(t, dims.without(value._stack_dim)) for t in value._tensors]
@@ -1790,8 +1789,9 @@ def _min(value: Tensor, dims: Shape) -> Tensor:
if value.shape.volume == 0:
return zeros(value.shape.without(dims), dtype=value.dtype)
if isinstance(value, NativeTensor):
-result = value.default_backend.min(value.native(value.shape), value.shape.indices(dims))
-return NativeTensor(result, value.shape.without(dims))
+result = value.default_backend.min(value.native(value.shape), value.shape.indices(dims.names))
+new_shape = value.shape.without(dims)
+return NativeTensor(result, new_shape.names, new_shape)
elif isinstance(value, TensorStack):
reduced_inners = [_min(t, dims.without(value._stack_dim)) for t in value._tensors]
return functools.reduce(lambda x, y: minimum(x, y), reduced_inners) if value._stack_dim in dims else TensorStack(reduced_inners, value._stack_dim)
@@ -2181,7 +2181,7 @@ def tensor_dot(x, y):
remaining_shape_y = y.shape.without(y_dims)
assert x_dims.volume == y_dims.volume, f"Failed to reduce {x_dims} against {y_dims} in dot product of {x.shape} and {y.shape}. Sizes do not match."
if remaining_shape_y.isdisjoint(remaining_shape_x): # no shared batch dimensions -> tensordot
-result_native = backend.tensordot(x_native, x.shape.indices(x_dims), y_native, y.shape.indices(y_dims))
+result_native = backend.tensordot(x_native, x.shape.indices(x_dims.names), y_native, y.shape.indices(y_dims.names))
result_shape = concat_shapes(remaining_shape_x, remaining_shape_y)
else: # shared batch dimensions -> einsum
result_shape = merge_shapes(x.shape.without(x_dims), y.shape.without(y_dims))
@@ -2204,7 +2204,7 @@ def tensor_dot(x, y):
keep_letters = [letter_map[dim] for dim in result_shape.names]
subscripts = f'{"".join(x_letters)},{"".join(y_letters)}->{"".join(keep_letters)}'
result_native = backend.einsum(subscripts, x_native, y_native)
return NativeTensor(result_native, result_shape)
return NativeTensor(result_native, result_shape.names, result_shape)
+return NativeTensor(result_native, result_shape.names, result_shape)

return broadcast_op(tensor_dot, [x, y])

@@ -3127,7 +3127,7 @@ def fft(x: Tensor, dims: DimFilter = spatial) -> Tensor:
"""
dims = x.shape.only(dims)
x_native = x.native(x.shape)
-result_native = choose_backend(x_native).fft(x_native, x.shape.indices(dims))
+result_native = choose_backend(x_native).fft(x_native, x.shape.indices(dims.names))
return NativeTensor(result_native, x.shape)


@@ -3145,7 +3145,7 @@ def ifft(k: Tensor, dims: DimFilter = spatial):
"""
dims = k.shape.only(dims)
k_native = k.native(k.shape)
-result_native = choose_backend(k_native).ifft(k_native, k.shape.indices(dims))
+result_native = choose_backend(k_native).ifft(k_native, k.shape.indices(dims.names))
return NativeTensor(result_native, k.shape)


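For context, a minimal numpy stand-in for the name-based reduction pattern used in _sum above (sum_named is a hypothetical helper; the real code dispatches through value.default_backend and additionally rescales by the volume of collapsed dims):

import numpy as np

def sum_named(native: np.ndarray, names: tuple, reduce_names: tuple):
    # Resolve reduction axes from dim names rather than from a Shape object.
    axes = tuple(names.index(n) for n in reduce_names)
    remaining = tuple(n for n in names if n not in reduce_names)
    return np.sum(native, axis=axes), remaining

data = np.ones((2, 3, 4))  # dims: batch, x, y
reduced, kept = sum_named(data, ('batch', 'x', 'y'), ('x', 'y'))
assert reduced.shape == (2,) and kept == ('batch',)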
phiml/math/_shape.py: 26 changes (13 additions, 13 deletions)
@@ -3,7 +3,7 @@
from dataclasses import dataclass, replace
from functools import cached_property
from numbers import Number
-from typing import Tuple, Callable, List, Union, Any, Sequence, Optional, Dict, Protocol, runtime_checkable
+from typing import Tuple, Callable, List, Union, Any, Sequence, Optional, Dict, Protocol, runtime_checkable, Iterable

from .. import math

@@ -106,15 +106,15 @@ def index(self, dim: Union[str, 'Shape', None]) -> int:
"""
...

-def indices(self, dims: 'Shape') -> Tuple[int]:
+def indices(self, names: Sequence[str]) -> Tuple[int, ...]:
"""
Finds the indices of the given dimensions within this `Shape`.
See Also:
`Shape.index()`.
Args:
-dims: Sequence of dimensions as `tuple`, `list` or `Shape`.
+names: Sequence of dim names as `tuple` or `list`. No name can occur in `names` more than once.
Returns:
Indices as `tuple[int]`.
@@ -936,8 +936,8 @@ def index(self, dim: Union[str, 'Shape', None]) -> Optional[int]:
return 0
raise ValueError(f"index() requires a single dimension as input but got {dim}")

-def indices(self, dims: Shape) -> Tuple[int, ...]:
-    return tuple([self.index(n) for n in dims.names])
+def indices(self, names: Sequence[str]) -> Tuple[int, ...]:
+    return (0,) if names else ()

def __getitem__(self, selection):
if isinstance(selection, Shape):
@@ -1300,8 +1300,9 @@ def index(self, dim: Union[str, 'Shape', None]) -> Optional[int]:
return self.names.index(dim.name)
raise ValueError(f"index() requires a single dimension as input but got {dim}")

-def indices(self, dims: Shape) -> Tuple[int, ...]:
-    return tuple([self.index(n) for n in dims.names])
+def indices(self, names: Sequence[str]) -> Tuple[int, ...]:
+    order = self.names
+    return tuple(order.index(n) for n in names)

def __getitem__(self, selection):
if isinstance(selection, int):
@@ -1663,8 +1664,9 @@ def index(self, dim: Union[str, 'Shape', None]) -> Optional[int]:
return self.names.index(dim.name)
raise ValueError(f"index() requires a single dimension as input but got {dim}")

-def indices(self, dims: Shape) -> Tuple[int, ...]:
-    return tuple([self.index(n) for n in dims.names])
+def indices(self, names: Sequence[str]) -> Tuple[int, ...]:
+    order = self.names
+    return tuple(order.index(n) for n in names)

def __getitem__(self, selection):
if isinstance(selection, int):
@@ -1680,10 +1682,8 @@ def __getitem__(self, selection):
elif isinstance(selection, Shape):
selection = selection.names
if isinstance(selection, (tuple, list)):
-raise NotImplementedError # this is expensive. Can we replace these calls?
-# names = [self.names[s] if isinstance(s, int) else s for s in selection]
-# dims = {name: self.dims[name] for name in names}
-# selection = [self.index(s) if isinstance(s, str) else s for s in selection]
+assert all(isinstance(s, str) for s in selection)
+return concat_shapes_(*[self.dims[n] for n in selection])
raise AssertionError("Can only access shape elements as shape[int], shape[str], shape[slice], shape[Sequence] or shape[Shape]")

def __iter__(self):
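The __getitem__ change above replaces the NotImplementedError branch: selecting a sub-shape by a tuple or list now accepts dim names only and concatenates the named dims. A hedged stand-in using a plain dict in place of a Shape (select_dims is illustrative, not phiml API):

from typing import Dict, Sequence, Tuple

def select_dims(dims: Dict[str, int], selection: Sequence[str]) -> Tuple[Tuple[str, int], ...]:
    # Name-only selection: ints are no longer accepted in the sequence branch.
    assert all(isinstance(s, str) for s in selection)
    return tuple((n, dims[n]) for n in selection)

shape = {'batch': 2, 'x': 32, 'y': 32, 'vector': 2}
assert select_dims(shape, ('y', 'x')) == (('y', 32), ('x', 32))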
phiml/math/_sparse.py: 8 changes (4 additions, 4 deletions)
@@ -352,11 +352,11 @@ def __pack_dims__(self, dims: Shape, packed_dim: Shape, pos: Union[int, None], *

def _with_shape_replaced(self, new_shape: Shape):
assert self._shape.rank == new_shape.rank
-dense_shape = new_shape[self._shape.indices(self._dense_shape)]
+dense_shape = new_shape[self._shape.indices(self._dense_shape.names)]
new_item_names = new_shape[self._shape.indices(self._indices.shape.get_item_names('sparse_idx'))].names
values = self._values._with_shape_replaced(self._values.shape.replace(self._shape, new_shape))
non_vec = self._shape.without('sparse_idx')
-new_non_vec = new_shape[self._shape.indices(non_vec)]
+new_non_vec = new_shape[self._shape.indices(non_vec.names)]
indices = self._indices._with_shape_replaced(self._indices.shape.replace(non_vec, new_non_vec).with_dim_size('sparse_idx', new_item_names))
m_rank = self._matrix_rank._with_shape_replaced(self._matrix_rank.shape.replace(self._shape, new_shape))
return SparseCoordinateTensor(indices, values, dense_shape, self._can_contain_double_entries, self._indices_sorted, self._indices_constant, m_rank)
@@ -932,7 +932,7 @@ def __pack_dims__(self, dims: Shape, packed_dim: Shape, pos: Union[int, None], *

def _with_shape_replaced(self, new_shape: Shape):
assert self._shape.rank == new_shape.rank
-compressed_dims = new_shape[self._shape.indices(self._compressed_dims)]
+compressed_dims = new_shape[self._shape.indices(self._compressed_dims.names)]
values = self._values._with_shape_replaced(self._values.shape.replace(self._shape, new_shape))
indices = self._indices._with_shape_replaced(self._indices.shape.replace(self._shape, new_shape))
m_rank = self._matrix_rank._with_shape_replaced(self._matrix_rank.shape.replace(self._shape, new_shape))
@@ -1479,7 +1479,7 @@ def dot_coordinate_dense(sparse: SparseCoordinateTensor, sdims: Shape, dense: Te


def dot_compact_dense(compact: CompactSparseTensor, cdims, dense: Tensor, ddims: Shape):
-gather_dims = ddims[cdims.indices(compact._compact_dims)]
+gather_dims = ddims[cdims.indices(compact._compact_dims.names)]
indices = expand(compact._indices, channel(_idx=gather_dims))
dense_gathered = dense[indices]
from ._ops import dot