diff --git a/torch_frame/data/stats.py b/torch_frame/data/stats.py index fbf6b633..6974a710 100644 --- a/torch_frame/data/stats.py +++ b/torch_frame/data/stats.py @@ -85,7 +85,8 @@ def compute( sep: str | None = None, ) -> Any: if self == StatType.MEAN: - flattened = np.hstack(np.hstack(ser.values)) + val = np.hstack(ser.values) if ser.values.ndim > 1 else ser.values + flattened = np.hstack(val) if val.ndim > 1 else val finite_mask = np.isfinite(flattened) if not finite_mask.any(): # NOTE: We may just error out here if eveything is NaN @@ -93,14 +94,16 @@ def compute( return np.mean(flattened[finite_mask]).item() elif self == StatType.STD: - flattened = np.hstack(np.hstack(ser.values)) + val = np.hstack(ser.values) if ser.values.ndim > 1 else ser.values + flattened = np.hstack(val) if val.ndim > 1 else val finite_mask = np.isfinite(flattened) if not finite_mask.any(): return np.nan return np.std(flattened[finite_mask]).item() elif self == StatType.QUANTILES: - flattened = np.hstack(np.hstack(ser.values)) + val = np.hstack(ser.values) if ser.values.ndim > 1 else ser.values + flattened = np.hstack(val) if val.ndim > 1 else val finite_mask = np.isfinite(flattened) if not finite_mask.any(): return [np.nan, np.nan, np.nan, np.nan, np.nan]