diff --git a/tests/python/pytorch/graphbolt/gb_test_utils.py b/tests/python/pytorch/graphbolt/gb_test_utils.py index dd7abc74da0c..59c4c3a90276 100644 --- a/tests/python/pytorch/graphbolt/gb_test_utils.py +++ b/tests/python/pytorch/graphbolt/gb_test_utils.py @@ -269,7 +269,7 @@ def genereate_raw_data_for_hetero_dataset( # Generate train/test/valid set. os.makedirs(os.path.join(test_dir, "set"), exist_ok=True) user_ids = torch.arange(num_nodes["user"]) - np.random.shuffle(user_ids) + np.random.shuffle(user_ids.numpy()) num_train = int(num_nodes["user"] * 0.6) num_validation = int(num_nodes["user"] * 0.2) num_test = num_nodes["user"] - num_train - num_validation