diff --git a/torch_frame/data/tensor_frame.py b/torch_frame/data/tensor_frame.py index 62d999f1..6ffb529c 100644 --- a/torch_frame/data/tensor_frame.py +++ b/torch_frame/data/tensor_frame.py @@ -69,10 +69,12 @@ def __init__( feat_dict: dict[torch_frame.stype, TensorData], col_names_dict: dict[torch_frame.stype, list[str]], y: Tensor | None = None, + num_rows: int | None = None, ) -> None: self.feat_dict = feat_dict self.col_names_dict = col_names_dict self.y = y + self._num_rows = num_rows self.validate() # Quick mapping from column names into their (stype, idx) pairs in @@ -175,6 +177,8 @@ def num_cols(self) -> int: @property def num_rows(self) -> int: r"""The number of rows in the :class:`TensorFrame`.""" + if self._num_rows is not None: + return self._num_rows if self.is_empty: return 0 feat = next(iter(self.feat_dict.values()))