diff --git a/src/tensortrax/__about__.py b/src/tensortrax/__about__.py index 32e6e64..3452608 100644 --- a/src/tensortrax/__about__.py +++ b/src/tensortrax/__about__.py @@ -2,4 +2,4 @@ tensorTRAX: Math on (Hyper-Dual) Tensors with Trailing Axes. """ -__version__ = "0.21.2" +__version__ = "0.21.3" diff --git a/src/tensortrax/math/special/_special_tensor.py b/src/tensortrax/math/special/_special_tensor.py index 9e85ddb..72108ea 100644 --- a/src/tensortrax/math/special/_special_tensor.py +++ b/src/tensortrax/math/special/_special_tensor.py @@ -82,6 +82,6 @@ def from_triu_2d(A): def try_stack(arrays, fallback=None): "Try to unpack and stack the list of tensors and return the fallback otherwise." try: - return stack([A for ary in arrays for A in ary]) + return stack([A.x for ary in arrays for A in ary]) except ValueError: return fallback diff --git a/tests/test_math.py b/tests/test_math.py index a46d695..4f67368 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -255,7 +255,7 @@ def test_try_stack(): fallback = "my fallback" stacked = tm.special.try_stack([C6, C6], fallback=fallback) - assert stacked.shape == (12,) + assert stacked.shape[0] == 12 assert tm.special.try_stack([C, C6], fallback=fallback) == fallback with pytest.raises(ValueError):