From 74e06e7d8872eb906e01ad78474a42614f7e4b72 Mon Sep 17 00:00:00 2001
From: Philipp Holl
Date: Sun, 22 Dec 2024 13:44:08 +0100
Subject: [PATCH] Shape refactor 5

* NativeTensor._names
---
 phiml/math/_ops.py | 44 ++++++++---------
 phiml/math/_shape.py | 26 +++++-----
 phiml/math/_sparse.py | 8 +--
 phiml/math/_tensors.py | 108 ++++++++++++++++++++++-------------
 4 files changed, 96 insertions(+), 90 deletions(-)

diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py
index 8a2041bf..ebd70c5d 100644
--- a/phiml/math/_ops.py
+++ b/phiml/math/_ops.py
@@ -17,7 +17,7 @@
 from ._tensors import (Tensor, wrap, tensor, broadcastable_native_tensors, NativeTensor, TensorStack, custom_op2, compatible_tensor,
                        variable_attributes, disassemble_tree, assemble_tree, is_scalar, Layout, expand_tensor, TensorOrTree, cached, variable_shape,
-                       reshaped_native, reshaped_tensor, discard_constant_dims)
+                       reshaped_native, reshaped_tensor, discard_constant_dims, variable_dim_names)
 from ._sparse import (CompressedSparseMatrix, dense, SparseCoordinateTensor, get_format, to_format, stored_indices, tensor_like, sparse_dims,
                       same_sparsity_pattern, is_sparse, sparse_dot, sparse_sum, sparse_gather, sparse_max, sparse_min, dense_dims, sparse_mean,
                       stored_values, sparse_matrix_dims, CompactSparseTensor)
@@ -742,7 +742,7 @@ def meshgrid(dims: Union[Callable, Shape] = spatial, stack_dim=channel('vector')
     grid_shape = dim_type(**{dim: size for dim, size in zip(dimensions.keys(), dim_sizes)})
     backend = choose_backend(*dim_values, prefer_default=True)
     indices_list = backend.meshgrid(*dim_values)
-    channels = [NativeTensor(t, grid_shape) for t in indices_list]
+    channels = [NativeTensor(t, grid_shape.names, grid_shape) for t in indices_list]
     if not stack_dim:
         assert len(channels) == 1, f"meshgrid with multiple dimension requires a valid stack_dim but got {stack_dim}"
         return channels[0]
@@ -882,13 +882,12 @@ def stack_tensors(values: Union[tuple, list], dim: Shape):
         return TensorStack(values, dim)
     # --- uniform stack ---
     dim = dim.with_size(len(values))
-    native_shapes = [variable_shape(v) for v in values]
-    native_broadcast_shape = merge_shapes(*native_shapes)
+    native_broadcast_shape = merge_shapes(*[variable_shape(v) for v in values])
     natives = [reshaped_native(discard_constant_dims(v), [*native_broadcast_shape], force_expand=True) for v in values]
-    native_shape = native_broadcast_shape & dim
-    native_stacked = choose_backend(*natives).stack(natives, axis=native_shape.index(dim))
+    names = (native_broadcast_shape & dim).names
+    native_stacked = choose_backend(*natives).stack(natives, axis=names.index(dim.name))
     expanded_shape = merge_shapes(*[v.shape for v in values]) & dim
-    return NativeTensor(native_stacked, native_shape, expanded_shape)
+    return NativeTensor(native_stacked, names, expanded_shape)
 
 
 def concat_tensor(values: Union[tuple, list], dim: str) -> Tensor:
@@ -1259,7 +1258,7 @@ def inner_where(c: Tensor, vt: Tensor, vf: Tensor):
             return c._with_values(where(c_values, vt_values, vf_values))
         shape, (c, vt, vf) = broadcastable_native_tensors(c, vt, vf)
         result = choose_backend(c, vt, vf).where(c, vt, vf)
-        return NativeTensor(result, shape)
+        return NativeTensor(result, shape.names, shape)
     return broadcast_op(inner_where, [condition, value_true, value_false])
@@ -1453,8 +1452,8 @@ def _sum(value: Tensor, dims: Shape) -> Tensor:
     if not dims:
         return value
     if isinstance(value, NativeTensor):
-        result = value.default_backend.sum(value._native, value._native_shape.indices(dims)) * value.collapsed_dims.only(dims).volume
-        return NativeTensor(result, value._native_shape.without(dims), value.shape.without(dims))
+        result = value.default_backend.sum(value._native, tuple(value._names.index(n) for n in dims.names)) * value.collapsed_dims.only(dims).volume
+        return NativeTensor(result, [n for n in value._names if n not in dims], value.shape.without(dims))
     elif isinstance(value, TensorStack):
         reduced_inners = [_sum(t, dims.without(value._stack_dim)) for t in value._tensors]
         return functools.reduce(lambda x, y: x + y, reduced_inners) if value._stack_dim in dims else TensorStack(reduced_inners, value._stack_dim)
@@ -1504,7 +1503,7 @@ def prod(value, dim: DimFilter = non_batch) -> Tensor:
 
 def _prod(value: Tensor, dims: Shape) -> Tensor:
     if isinstance(value, NativeTensor):
-        result = value.default_backend.prod(value._native, value._native_shape.indices(dims)) ** value.collapsed_dims.only(dims).volume
+        result = value.default_backend.prod(value._native, value._native_shape.indices(dims.names)) ** value.collapsed_dims.only(dims).volume
         return NativeTensor(result, value._native_shape.without(dims), value.shape.without(dims))
     elif isinstance(value, TensorStack):
         reduced_inners = [_prod(t, dims.without(value._stack_dim)) for t in value._tensors]
@@ -1569,7 +1568,7 @@ def _mean(value: Tensor, dims: Shape) -> Tensor:
     if not dims:
         return value
     if isinstance(value, NativeTensor):
-        result = value.default_backend.mean(value._native, value._native_shape.indices(dims))
+        result = value.default_backend.mean(value._native, value._native_shape.indices(dims.names))
         return NativeTensor(result, value._native_shape.without(dims), value.shape.without(dims))
     elif isinstance(value, TensorStack):
         if value._stack_dim in dims:
@@ -1611,7 +1610,7 @@ def std(value, dim: DimFilter = non_batch) -> Tensor:
 
 def _std(value: Tensor, dims: Shape) -> Tensor:
     if value.shape.is_uniform:
-        result = value.default_backend.std(value.native(value.shape), value.shape.indices(dims))
+        result = value.default_backend.std(value.native(value.shape), value.shape.indices(dims.names))
         return NativeTensor(result, value.shape.without(dims))
     else:
         non_uniform_dims = value.shape.shape.without('dims')
@@ -1647,7 +1646,7 @@ def any_(boolean_value, dim: DimFilter = non_batch) -> Tensor:
 
 def _any(value: Tensor, dims: Shape) -> Tensor:
     if isinstance(value, NativeTensor):
-        result = value.default_backend.any(value._native, value._native_shape.indices(dims))
+        result = value.default_backend.any(value._native, value._native_shape.indices(dims.names))
         return NativeTensor(result, value._native_shape.without(dims), value.shape.without(dims))
     elif isinstance(value, TensorStack):
         reduced_inners = [_any(t, dims.without(value._stack_dim)) for t in value._tensors]
@@ -1681,7 +1680,7 @@ def all_(boolean_value, dim: DimFilter = non_batch) -> Tensor:
 
 def _all(value: Tensor, dims: Shape) -> Tensor:
     if isinstance(value, NativeTensor):
-        result = value.default_backend.all(value.native(value.shape), value.shape.indices(dims))
+        result = value.default_backend.all(value.native(value.shape), value.shape.indices(dims.names))
         return NativeTensor(result, value.shape.without(dims))
     elif isinstance(value, TensorStack):
         reduced_inners = [_all(t, dims.without(value._stack_dim)) for t in value._tensors]
@@ -1738,7 +1737,7 @@ def _max(value: Tensor, dims: Shape) -> Tensor:
     if value.shape.volume == 0:
         return zeros(value.shape.without(dims), dtype=value.dtype)
     if isinstance(value, NativeTensor):
-        result = value.default_backend.max(value.native(value.shape), value.shape.indices(dims))
+        result = value.default_backend.max(value.native(value.shape), value.shape.indices(dims.names))
         return NativeTensor(result, value.shape.without(dims))
     elif isinstance(value, TensorStack):
         reduced_inners = [_max(t, dims.without(value._stack_dim)) for t in value._tensors]
@@ -1790,8 +1789,9 @@ def _min(value: Tensor, dims: Shape) -> Tensor:
     if value.shape.volume == 0:
         return zeros(value.shape.without(dims), dtype=value.dtype)
     if isinstance(value, NativeTensor):
-        result = value.default_backend.min(value.native(value.shape), value.shape.indices(dims))
-        return NativeTensor(result, value.shape.without(dims))
+        result = value.default_backend.min(value.native(value.shape), value.shape.indices(dims.names))
+        new_shape = value.shape.without(dims)
+        return NativeTensor(result, new_shape.names, new_shape)
     elif isinstance(value, TensorStack):
         reduced_inners = [_min(t, dims.without(value._stack_dim)) for t in value._tensors]
         return functools.reduce(lambda x, y: minimum(x, y), reduced_inners) if value._stack_dim in dims else TensorStack(reduced_inners, value._stack_dim)
@@ -2181,7 +2181,7 @@ def tensor_dot(x, y):
         remaining_shape_y = y.shape.without(y_dims)
         assert x_dims.volume == y_dims.volume, f"Failed to reduce {x_dims} against {y_dims} in dot product of {x.shape} and {y.shape}. Sizes do not match."
         if remaining_shape_y.isdisjoint(remaining_shape_x):  # no shared batch dimensions -> tensordot
-            result_native = backend.tensordot(x_native, x.shape.indices(x_dims), y_native, y.shape.indices(y_dims))
+            result_native = backend.tensordot(x_native, x.shape.indices(x_dims.names), y_native, y.shape.indices(y_dims.names))
             result_shape = concat_shapes(remaining_shape_x, remaining_shape_y)
         else:  # shared batch dimensions -> einsum
             result_shape = merge_shapes(x.shape.without(x_dims), y.shape.without(y_dims))
@@ -2204,7 +2204,7 @@ def tensor_dot(x, y):
             keep_letters = [letter_map[dim] for dim in result_shape.names]
             subscripts = f'{"".join(x_letters)},{"".join(y_letters)}->{"".join(keep_letters)}'
             result_native = backend.einsum(subscripts, x_native, y_native)
-        return NativeTensor(result_native, result_shape)
+        return NativeTensor(result_native, result_shape.names, result_shape)
     return broadcast_op(tensor_dot, [x, y])
@@ -3127,7 +3127,7 @@ def fft(x: Tensor, dims: DimFilter = spatial) -> Tensor:
     """
     dims = x.shape.only(dims)
     x_native = x.native(x.shape)
-    result_native = choose_backend(x_native).fft(x_native, x.shape.indices(dims))
+    result_native = choose_backend(x_native).fft(x_native, x.shape.indices(dims.names))
     return NativeTensor(result_native, x.shape)
@@ -3145,7 +3145,7 @@ def ifft(k: Tensor, dims: DimFilter = spatial):
     """
     dims = k.shape.only(dims)
     k_native = k.native(k.shape)
-    result_native = choose_backend(k_native).ifft(k_native, k.shape.indices(dims))
+    result_native = choose_backend(k_native).ifft(k_native, k.shape.indices(dims.names))
     return NativeTensor(result_native, k.shape)
diff --git a/phiml/math/_shape.py b/phiml/math/_shape.py
index 6274f8ff..3b0c8b87 100644
--- a/phiml/math/_shape.py
+++ b/phiml/math/_shape.py
@@ -3,7 +3,7 @@
 from dataclasses import dataclass, replace
 from functools import cached_property
 from numbers import Number
-from typing import Tuple, Callable, List, Union, Any, Sequence, Optional, Dict, Protocol, runtime_checkable
+from typing import Tuple, Callable, List, Union, Any, Sequence, Optional, Dict, Protocol, runtime_checkable, Iterable
 from .. import math
@@ -106,7 +106,7 @@ def index(self, dim: Union[str, 'Shape', None]) -> int:
         """
         ...
 
-    def indices(self, dims: 'Shape') -> Tuple[int]:
+    def indices(self, names: Sequence[str]) -> Tuple[int, ...]:
         """
         Finds the indices of the given dimensions within this `Shape`.
 
@@ -114,7 +114,7 @@ def index(self, dim: Union[str, 'Shape', None]) -> int:
             `Shape.index()`.
 
         Args:
-            dims: Sequence of dimensions as `tuple`, `list` or `Shape`.
+            names: Sequence of dim names as `tuple` or `list`. No name can occur in `names` more than once.
 
         Returns:
             Indices as `tuple[int]`.
         """
         ...
@@ -936,8 +936,8 @@ def index(self, dim: Union[str, 'Shape', None]) -> Optional[int]:
             return 0
         raise ValueError(f"index() requires a single dimension as input but got {dim}")
 
-    def indices(self, dims: Shape) -> Tuple[int, ...]:
-        return tuple([self.index(n) for n in dims.names])
+    def indices(self, names: Sequence[str]) -> Tuple[int, ...]:
+        return (0,) if names else ()
 
     def __getitem__(self, selection):
         if isinstance(selection, Shape):
@@ -1300,8 +1300,9 @@ def index(self, dim: Union[str, 'Shape', None]) -> Optional[int]:
             return self.names.index(dim.name)
         raise ValueError(f"index() requires a single dimension as input but got {dim}")
 
-    def indices(self, dims: Shape) -> Tuple[int, ...]:
-        return tuple([self.index(n) for n in dims.names])
+    def indices(self, names: Sequence[str]) -> Tuple[int, ...]:
+        order = self.names
+        return tuple(order.index(n) for n in names)
 
     def __getitem__(self, selection):
         if isinstance(selection, int):
@@ -1663,8 +1664,9 @@ def index(self, dim: Union[str, 'Shape', None]) -> Optional[int]:
             return self.names.index(dim.name)
         raise ValueError(f"index() requires a single dimension as input but got {dim}")
 
-    def indices(self, dims: Shape) -> Tuple[int, ...]:
-        return tuple([self.index(n) for n in dims.names])
+    def indices(self, names: Sequence[str]) -> Tuple[int, ...]:
+        order = self.names
+        return tuple(order.index(n) for n in names)
 
     def __getitem__(self, selection):
         if isinstance(selection, int):
@@ -1680,10 +1682,8 @@ def __getitem__(self, selection):
         elif isinstance(selection, Shape):
             selection = selection.names
         if isinstance(selection, (tuple, list)):
-            raise NotImplementedError  # this is expensive. Can we replace these calls?
-            # names = [self.names[s] if isinstance(s, int) else s for s in selection]
-            # dims = {name: self.dims[name] for name in names}
-            # selection = [self.index(s) if isinstance(s, str) else s for s in selection]
+            assert all(isinstance(s, str) for s in selection)
+            return concat_shapes_(*[self.dims[n] for n in selection])
         raise AssertionError("Can only access shape elements as shape[int], shape[str], shape[slice], shape[Sequence] or shape[Shape]")
 
     def __iter__(self):
diff --git a/phiml/math/_sparse.py b/phiml/math/_sparse.py
index 27e53614..885567af 100644
--- a/phiml/math/_sparse.py
+++ b/phiml/math/_sparse.py
@@ -352,11 +352,11 @@ def __pack_dims__(self, dims: Shape, packed_dim: Shape, pos: Union[int, None], *
 
     def _with_shape_replaced(self, new_shape: Shape):
         assert self._shape.rank == new_shape.rank
-        dense_shape = new_shape[self._shape.indices(self._dense_shape)]
+        dense_shape = new_shape[self._shape.indices(self._dense_shape.names)]
         new_item_names = new_shape[self._shape.indices(self._indices.shape.get_item_names('sparse_idx'))].names
         values = self._values._with_shape_replaced(self._values.shape.replace(self._shape, new_shape))
         non_vec = self._shape.without('sparse_idx')
-        new_non_vec = new_shape[self._shape.indices(non_vec)]
+        new_non_vec = new_shape[self._shape.indices(non_vec.names)]
         indices = self._indices._with_shape_replaced(self._indices.shape.replace(non_vec, new_non_vec).with_dim_size('sparse_idx', new_item_names))
         m_rank = self._matrix_rank._with_shape_replaced(self._matrix_rank.shape.replace(self._shape, new_shape))
         return SparseCoordinateTensor(indices, values, dense_shape, self._can_contain_double_entries, self._indices_sorted, self._indices_constant, m_rank)
@@ -932,7 +932,7 @@ def __pack_dims__(self, dims: Shape, packed_dim: Shape, pos: Union[int, None], *
 
     def _with_shape_replaced(self, new_shape: Shape):
         assert self._shape.rank == new_shape.rank
-        compressed_dims = new_shape[self._shape.indices(self._compressed_dims)]
+        compressed_dims = new_shape[self._shape.indices(self._compressed_dims.names)]
         values = self._values._with_shape_replaced(self._values.shape.replace(self._shape, new_shape))
         indices = self._indices._with_shape_replaced(self._indices.shape.replace(self._shape, new_shape))
         m_rank = self._matrix_rank._with_shape_replaced(self._matrix_rank.shape.replace(self._shape, new_shape))
@@ -1479,7 +1479,7 @@ def dot_coordinate_dense(sparse: SparseCoordinateTensor, sdims: Shape, dense: Te
 
 def dot_compact_dense(compact: CompactSparseTensor, cdims, dense: Tensor, ddims: Shape):
-    gather_dims = ddims[cdims.indices(compact._compact_dims)]
+    gather_dims = ddims[cdims.indices(compact._compact_dims.names)]
     indices = expand(compact._indices, channel(_idx=gather_dims))
     dense_gathered = dense[indices]
     from ._ops import dot
diff --git a/phiml/math/_tensors.py b/phiml/math/_tensors.py
index 2602c771..656d7e46 100644
--- a/phiml/math/_tensors.py
+++ b/phiml/math/_tensors.py
@@ -535,7 +535,7 @@ def __unpack_dim__(self, dim: str, unpacked_dims: Shape, **kwargs) -> 'Tensor':
             new_shape = new_shape.with_sizes(sizes)
         if new_shape.is_uniform:
             native_reshaped = choose_backend(native).reshape(native, new_shape.sizes)
-            return NativeTensor(native_reshaped, new_shape)
+            return NativeTensor(native_reshaped, new_shape.names, new_shape)
         else:
             split_dim = new_shape.non_uniform_shape[-1]
             i = 0
@@ -566,12 +566,12 @@ def __pack_dims__(self, dims: Shape, packed_dim: Shape, pos: Union[int, None], *
                 order.append(dim)
             native = self._transposed_native(order, force_expand=True)
             if pos is None:
-                pos = min(self.shape.indices(dims))
+                pos = min(self.shape.indices(dims.names))
             packed_dim = packed_dim.with_sizes([dims.volume])
             remaining = self.shape - dims
             new_shape = concat_shapes_(remaining[:pos], packed_dim, remaining[pos:])
             native = choose_backend(native).reshape(native, new_shape.sizes)
-            return NativeTensor(native, new_shape)
+            return NativeTensor(native, new_shape.names, new_shape)
         else:
             from ._ops import concat_tensor
             value = cached(self)
@@ -1228,34 +1228,34 @@ class NativeTensor(Tensor):
     The property _shape can contain additional dimensions along which the tensor is constant.
     """
 
-    def __init__(self, native_tensor, native_shape: Shape, expanded_shape: Shape = None):
+    def __init__(self, native_tensor, names: Sequence[str], expanded_shape: Shape):
         super().__init__()
-        expanded_shape = native_shape if expanded_shape is None else expanded_shape
+        self._native = native_tensor
+        self._shape = expanded_shape
+        self._names = names
         if DEBUG_CHECKS:
+            assert isinstance(names, (tuple, list))
+            assert all(isinstance(n, str) for n in names)
             for dim in expanded_shape:
                 if dim.size is not None and isinstance(dim.size, Tensor):
                     assert dim.size.rank > 0
                     for s_dim in dim.size.shape.names:
                         assert s_dim in expanded_shape.names, f"Dimension {dim} varies along {s_dim} but {s_dim} is not part of the Shape {self}"
             backend = choose_backend(native_tensor)
-            assert native_shape.is_uniform
             assert expanded_shape.is_uniform
-            assert backend.staticshape(native_tensor) == native_shape.sizes, f"Shape {native_shape} does not match native tensor with shape {backend.staticshape(native_tensor)}"
-        assert native_shape in expanded_shape
-        self._native = native_tensor
-        self._shape = expanded_shape
-        self._native_shape = native_shape
+            shape_sizes = [expanded_shape.get_size(n) for n in names]
+            assert backend.staticshape(native_tensor) == tuple(shape_sizes), f"Shape {expanded_shape} at {names} does not match native tensor with shape {backend.staticshape(native_tensor)}"
 
     def _transposed_native(self, order: Sequence[str], force_expand: bool):
-        assert all([n in order for n in self._native_shape.names]), f"Failed to get native tensor because dims {[n for n in self._native_shape.names if n not in order]} were not specified in the dim order. Got {order} for tensor {self.shape}"
+        assert all([n in order for n in self._names]), f"Failed to get native tensor because dims {[n for n in self._names if n not in order]} were not specified in the dim order. Got {order} for tensor {self.shape}"
         backend = self.default_backend
-        if order == self._native_shape.names:
+        if order == self._names:
             if self.dtype.precision in [None, get_precision()]:
                 return self._native
             else:
                 return backend.cast(self._native, DType(self.dtype.kind, precision=get_precision()))
         # --- Transpose ---
-        perm = [self._native_shape.index(dim) for dim in order if dim in self._native_shape]
+        perm = [self._names.index(n) for n in order if n in self._names]
         if perm != list(range(len(perm))):
             transposed = backend.transpose(self._native, perm)  # this will cast automatically
         else:
@@ -1263,21 +1263,21 @@
         if len(order) == len(perm):
             return transposed  # nothing to expand
         # --- Expand ---
-        slices = [slice(None) if dim in self._native_shape else None for dim in order]
+        slices = [slice(None) if n in self._names else None for n in order]
         expanded = transposed[tuple(slices)]
         if force_expand:
-            multiples = [self._shape.get_size(dim) if dim in self._shape and dim not in self._native_shape else 1 for dim in order]
+            multiples = [self._shape.get_size(dim) if dim in self._shape and dim not in self._names else 1 for dim in order]
             expanded = backend.tile(expanded, multiples)
         return expanded
 
     def _contiguous(self):
-        if self._shape == self._native_shape:
+        if len(self._names) == len(self._shape):
            return self
         expanded = self.native(order=self._shape)
         return NativeTensor(expanded, self._shape, self._shape)
 
     def _cached(self, dims: Shape = None) -> 'NativeTensor':
-        if self._native_shape == self._shape:  # nothing to expand
+        if len(self._names) == len(self._shape):  # nothing to expand
             return self
         elif dims is None or self._shape in (dims & self._native_shape):  # expand all
             return NativeTensor(self.native(order=self._shape), self._shape, self._shape)
@@ -1288,7 +1288,7 @@ def _cached(self, dims: Shape = None) -> 'NativeTensor':
 
     @property
     def collapsed_dims(self):
-        return self._shape.without(self._native_shape)
+        return self._shape.without(self._names)
 
     @property
     def dtype(self):
@@ -1306,15 +1306,16 @@ def _with_shape_replaced(self, new_shape):
         if new_shape.rank != self._shape.rank:
             raise IncompatibleShapes(f"Tensor {self} is not compatible with shape {new_shape}", self._shape, new_shape)
         new_shape = new_shape.with_sizes(self._shape.sizes)
-        native_indices = self._shape.indices(self._native_shape)
-        new_native_shape = concat_shapes_(*[new_shape[i] for i in native_indices])
-        return NativeTensor(self._native, new_native_shape, new_shape)
+        name_map = {old: new for old, new in zip(self._shape.names, new_shape.names)}
+        names = [name_map[n] for n in self._names]
+        return NativeTensor(self._native, names, new_shape)
 
     def _with_natives_replaced(self, natives: list):
         native = natives.pop(0)
-        new_native_shape = self._native_shape.with_sizes(choose_backend(native).shape(native))
-        new_shape = self._shape.with_sizes(new_native_shape)
-        return NativeTensor(native, new_native_shape, new_shape)
+        sizes = choose_backend(native).shape(native)
+        assert sizes == choose_backend(self._native).shape(self._native)
+        # new_shape = self._shape.with_sizes(new_native_shape)
+        return NativeTensor(native, self._names, self._shape)
 
     @property
     def _is_tracer(self) -> bool:
@@ -1331,30 +1332,31 @@ def _to_dict(self):
 
     def _getitem(self, selection: dict):
         if not selection:
             return self
-        selections = [slice(None)] * self._native_shape.rank
+        selections = [slice(None)] * len(self._names)
         for name, sel in selection.items():
-            if name in self._native_shape:
-                selections[self._native_shape.index(name)] = sel
+            assert isinstance(name, str)
+            if name in self._names:
+                selections[self._names.index(name)] = sel
             elif name not in self._shape:
                 assert isinstance(sel, int), f"Attempting slice missing dimension {name} with {selection}"
         gathered = self.default_backend.multi_slice(self._native, tuple(selections)) if selections else self._native
-        new_native_shape = after_gather(self._native_shape, selection)
-        new_shape = after_gather(self._shape, selection)
-        return NativeTensor(gathered, new_native_shape, new_shape)
+        new_native_shape = after_gather(self._shape[self._names], selection)
+        new_shape = self.collapsed_dims & new_native_shape
+        return NativeTensor(gathered, new_native_shape.names, new_shape)
 
-    def _unstack(self, dim):
+    def _unstack(self, dim: str):
         new_shape = self._shape.without(dim)
-        new_native_shape = self._native_shape.without(dim)
-        if dim in self._native_shape:
-            tensors = self.default_backend.unstack(self._native, axis=self._native_shape.index(dim))
-            return tuple([NativeTensor(t, new_native_shape, new_shape) for t in tensors])
+        if dim in self._names:
+            tensors = self.default_backend.unstack(self._native, axis=self._names.index(dim))
+            new_names = [n for n in self._names if n != dim]
+            return tuple([NativeTensor(t, new_names, new_shape) for t in tensors])
         else:
             assert dim in self._shape, f"Cannot unstack tensor {self._shape} along non-existant dimension '{dim}'"
-            return (NativeTensor(self._native, new_native_shape, new_shape),) * self._shape.get_size(dim)
+            return (NativeTensor(self._native, self._names, new_shape),) * self._shape.get_size(dim)
 
     def _op1(self, native_function):
         native = native_function(self._native)
-        return NativeTensor(native, self._native_shape, self._shape) if native is not None else self
+        return NativeTensor(native, self._names, self._shape) if native is not None else self
 
     def _op2(self, other, operator, native_function, op_name: str = 'unknown', op_symbol: str = '?', switch_args=False):
         try:
@@ -1366,18 +1368,18 @@ def _op2(self, other, operator, native_function, op_name: str = 'unknown', op_sy
             return NotImplemented
         if not isinstance(other_tensor, NativeTensor):
             other_tensor = NativeTensor(other_tensor.native(other_tensor.shape), other_tensor.shape, other_tensor.shape)
-        broadcast_shape = self._native_shape & other_tensor._native_shape
-        natives = [t.native(order=broadcast_shape, force_expand=False) if t.rank > 0 else t.native() for t in [self, other_tensor]]
+        broadcast_names = tuple(set(self._names) | set(other_tensor._names))
+        natives = [t.native(order=broadcast_names, force_expand=False) if t.rank > 0 else t.native() for t in [self, other_tensor]]
         if switch_args:
             natives = natives[::-1]
         result_tensor = native_function(*natives)
-        return NativeTensor(result_tensor, broadcast_shape, self._shape & other_tensor._shape)
+        return NativeTensor(result_tensor, broadcast_names, self._shape & other_tensor._shape)
 
     def _natives(self) -> tuple:
         return self._native,
 
     def _spec_dict(self) -> dict:
-        return {'type': NativeTensor, 'native_shape': self._native_shape, 'shape': self._shape}
+        return {'type': NativeTensor, 'names': self._names, 'shape': self._shape}
 
     @classmethod
     def _from_spec_and_natives(cls, spec: dict, natives: list):
@@ -1655,7 +1657,7 @@ def tensor(data,
         assert not shape, f"Trying to create a zero-dimensional Tensor from value '{data}' but shape={shape}"
         if convert:
             data = default_backend().as_tensor(data, convert_external=True)
-        return NativeTensor(data, EMPTY_SHAPE)
+        return NativeTensor(data, (), EMPTY_SHAPE)
     if isinstance(data, (tuple, list)):
         if all(isinstance(d, (bool, int, float, complex, np.generic)) for d in data):
             array = np.array(data)
@@ -1697,7 +1699,7 @@ def tensor(data,
             if 0 in sizes:
                 present_shape = shape[:len(sizes)].with_sizes(sizes)
                 return NativeTensor(data, present_shape, shape.with_sizes(shape.undefined.with_sizes(0)).with_sizes(present_shape))
-            return NativeTensor(data, shape)
+            return NativeTensor(data, shape.names, shape)
     except NoBackendFound:
         raise ValueError(f"{type(data)} is not supported. Only (Tensor, tuple, list, np.ndarray, native tensors) are allowed.\nCurrent backends: {BACKENDS}")
@@ -1791,7 +1793,7 @@ def compatible_tensor(data, compat_shape: Shape = None, compat_natives=(), conve
         except ValueError as e:
             raise ValueError(e)
         if len(shape) == 0:
-            return NativeTensor(data, EMPTY_SHAPE)
+            return NativeTensor(data, (), EMPTY_SHAPE)
         elif isinstance(data, (tuple, list)):  # always channel, add vector if not available
             data = backend.as_tensor(data)
             if len(shape) == compat_shape.channel_rank:
@@ -1954,7 +1956,7 @@ def disassemble_tree(obj: PhiTreeNodeType, cache: bool, attr_type=variable_attri
             sizes = backend.staticshape(obj)
             dims = [Dim(f"dim{i}", s, CHANNEL_DIM, None) for i, s in enumerate(sizes)]
             shape = PureShape(CHANNEL_DIM, {dim.name: dim for dim in dims})
-            return NATIVE_TENSOR, [NativeTensor(obj, shape)]
+            return NATIVE_TENSOR, [NativeTensor(obj, shape.names, shape)]
         except NoBackendFound:
             return obj, []
@@ -2088,7 +2090,7 @@ def expand_tensor(value: Tensor, dims: Shape):
     assert dims.well_defined
     if isinstance(value, NativeTensor):
         if dims.is_uniform:
-            return NativeTensor(value._native, value._native_shape, dims & value._shape)
+            return NativeTensor(value._native, value._names, dims & value._shape)
         else:
             stack_dim = dims.shape.without('dims')
             if stack_dim.rank > 1:
@@ -2096,9 +2098,9 @@ def expand_tensor(value: Tensor, dims: Shape):
         unstacked_dims = [after_gather(dims, i) for i in stack_dim.meshgrid()]
         if stack_dim in value.shape:
             unstacked = unstack(value, stack_dim)
-            components = [NativeTensor(inner._native, inner._native_shape, inner_shape & inner._native_shape) for inner_shape, inner in zip(unstacked_dims, unstacked)]
+            components = [NativeTensor(inner._native, inner._names, inner_shape & inner._native_shape) for inner_shape, inner in zip(unstacked_dims, unstacked)]
         else:
-            components = [NativeTensor(value._native, value._native_shape, inner_shape & value._native_shape) for inner_shape in unstacked_dims]
+            components = [NativeTensor(value._native, value._names, inner_shape & value._native_shape) for inner_shape in unstacked_dims]
         return TensorStack(components, stack_dim)
     if isinstance(value, TensorStack):
         expanded = [expand_tensor(v, after_gather(dims, {value._stack_dim.name: i})) for i, v in enumerate(value._tensors)]
@@ -2938,8 +2940,12 @@ def is_scalar(value) -> bool:
     return len(choose_backend(value).staticshape(value)) == 0
 
 
-def variable_shape(value: Tensor):
-    return value._native_shape if isinstance(value, NativeTensor) else shape(value)
+def variable_shape(value):
+    return value._shape - value.collapsed_dims if isinstance(value, NativeTensor) else shape(value)
+
+
+def variable_dim_names(value):
+    return value._names if isinstance(value, NativeTensor) else shape(value).names
 
 
 def may_vary_along(value: Tensor, dims: DimFilter):
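
Usage sketch (editor's illustration, not part of the commit): the snippet below shows how the refactored pieces are meant to fit together after this patch — `NativeTensor(native, names, expanded_shape)` taking the native dim names plus the expanded `Shape`, and `Shape.indices()` taking a sequence of dim names. The dim names `b`, `x`, `y`, the sizes and the NumPy array are made-up example values; only the signatures and properties come from the hunks above.

    # Hypothetical example values; call signatures follow the patch above.
    import numpy as np
    from phiml.math import batch, spatial
    from phiml.math._tensors import NativeTensor

    native = np.zeros((4, 3))                      # data is stored only along y and x
    names = ('y', 'x')                             # order of the dims in `native`
    full_shape = batch(b=10) & spatial(x=3, y=4)   # 'b' is constant, i.e. not stored natively
    t = NativeTensor(native, names, full_shape)

    t.shape           # expanded shape containing b, x and y
    t.collapsed_dims  # dims present in _shape but not in _names, here only b

    # Shape.indices() now takes dim names instead of a Shape:
    full_shape.indices(('y', 'b'))  # tuple of positions of 'y' and 'b' within full_shape

The reductions in `_ops.py` follow the same pattern: `_sum`, for example, resolves native axes via `value._names.index(n)` and drops the reduced names from both `_names` and the expanded shape.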