diff --git a/phiml/math/_sparse.py b/phiml/math/_sparse.py index 1902c871..21bcc71c 100644 --- a/phiml/math/_sparse.py +++ b/phiml/math/_sparse.py @@ -762,9 +762,9 @@ def _native_csr_components(self, invalid='clamp', get_values=True): assert invalid in ['clamp', 'discard', 'keep'] ind_batch = batch(self._indices) & batch(self._pointers) channels = non_instance(self._values).without(ind_batch) - native_indices = self._indices.native([ind_batch, instance]) - native_pointers = self._pointers.native([ind_batch, instance]) - native_values = self._values.native([ind_batch, instance, channels]) if get_values else None + native_indices = self._indices._reshaped_native([ind_batch, instance(self._indices).without_sizes()]) # allow variable instance size (PyTorch tracing) + native_pointers = self._pointers._reshaped_native([ind_batch, instance(self._pointers)]) + native_values = self._values._reshaped_native([ind_batch, instance(self._values).without_sizes(), channels]) if get_values else None native_shape = self._compressed_dims.volume, self._uncompressed_dims.volume if self._uncompressed_offset is not None: native_indices -= self._uncompressed_offset diff --git a/phiml/math/_tensors.py b/phiml/math/_tensors.py index f0613afa..1301f0a7 100644 --- a/phiml/math/_tensors.py +++ b/phiml/math/_tensors.py @@ -1228,7 +1228,7 @@ def __init__(self, native_tensor, names: Sequence[str], expanded_shape: Shape, b for s_dim in dim.size.shape.names: assert s_dim in expanded_shape.names, f"Dimension {dim} varies along {s_dim} but {s_dim} is not part of the Shape {self}" assert choose_backend(native_tensor) == backend - assert expanded_shape.is_uniform + assert expanded_shape.is_uniform, expanded_shape shape_sizes = [expanded_shape.get_size(n) for n in names] assert backend.staticshape(native_tensor) == tuple(shape_sizes), f"Shape {expanded_shape} at {names} does not match native tensor with shape {backend.staticshape(native_tensor)}" @@ -1264,7 +1264,7 @@ def _reshaped_native(self, groups: Sequence[Shape]): native = self._backend.transpose(self._native, perm) # this will cast automatically native = native[tuple(slices)] native = self._backend.tile(native, tile) - native = self._backend.reshape(native, [g.volume for g in groups]) + native = self._backend.reshape(native, [g.volume if g.well_defined else -1 for g in groups]) return native def _transposed_native(self, order: Sequence[str], force_expand: bool):