diff --git a/test/data/test_tensor_frame.py b/test/data/test_tensor_frame.py index 6c283ecc..a0645b18 100644 --- a/test/data/test_tensor_frame.py +++ b/test/data/test_tensor_frame.py @@ -245,3 +245,11 @@ def test_custom_tf_get_col_feat(): assert torch.equal(feat, feat_dict['numerical'][:, 0:1]) feat = tf.get_col_feat('num_2') assert torch.equal(feat, feat_dict['numerical'][:, 1:2]) + + +def test_non_list_col_names_dict(): + feat_dict = {torch_frame.categorical: torch.randint(0, 3, size=(10, 1))} + # Oops, user provided a single column name without wrapping it in a list: + col_names_dict = {torch_frame.categorical: 'cat_1'} + with pytest.raises(ValueError, match='must be a list of column names'): + TensorFrame(feat_dict, col_names_dict) diff --git a/torch_frame/data/tensor_frame.py b/torch_frame/data/tensor_frame.py index c9b9c56d..e88d8137 100644 --- a/torch_frame/data/tensor_frame.py +++ b/torch_frame/data/tensor_frame.py @@ -95,7 +95,13 @@ def validate(self) -> None: num_rows = self.num_rows empty_stypes: list[torch_frame.stype] = [] for stype_name, feats in self.feat_dict.items(): - num_cols = len(self.col_names_dict[stype_name]) + col_names = self.col_names_dict[stype_name] + if not isinstance(col_names, list): + raise ValueError( + f"col_names_dict[{stype_name}] must be a list of column " + f"names.") + + num_cols = len(col_names) if num_cols == 0: empty_stypes.append(stype_name)