From 4ede11280b494de6e6a0f815c8037de4742ae9ce Mon Sep 17 00:00:00 2001 From: Yiwen Yuan Date: Wed, 6 Nov 2024 08:01:49 +0000 Subject: [PATCH] fix bug --- benchmark/tgt_ijcai_benchmark.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/benchmark/tgt_ijcai_benchmark.py b/benchmark/tgt_ijcai_benchmark.py index f264c60..2e422a3 100644 --- a/benchmark/tgt_ijcai_benchmark.py +++ b/benchmark/tgt_ijcai_benchmark.py @@ -401,7 +401,7 @@ def test(model: torch.nn.Module, loader: NeighborLoader, desc: str, stage: str, input_time=torch.full((num_src_nodes, ), split_date[split], dtype=torch.long), subgraph_type="bidirectional", - batch_size=args.batch_size, + batch_size=train_search_space['batch_size'], temporal_strategy=args.temporal_strategy, shuffle=split == "train", num_workers=args.num_workers, @@ -460,7 +460,8 @@ def train_and_eval_with_cfg( print(f"Train Loss: {train_loss:.4f}") if epoch % 5 == 0: # Check if we should early stop - test_pred = test(model, loader_dict["test"], "test", "test") + test_pred = test(model, loader_dict["test"], "test", "test", + target=target_list) test_metric = calculate_hit_rate(test_pred, target_list) print(f"Test metric: {test_metric:.4f}") if test_metric > best_test_metric: