Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiwen Yuan committed Nov 6, 2024
1 parent ae17de2 commit 4ede112
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions benchmark/tgt_ijcai_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def test(model: torch.nn.Module, loader: NeighborLoader, desc: str, stage: str,
input_time=torch.full((num_src_nodes, ), split_date[split],
dtype=torch.long),
subgraph_type="bidirectional",
batch_size=args.batch_size,
batch_size=train_search_space['batch_size'],
temporal_strategy=args.temporal_strategy,
shuffle=split == "train",
num_workers=args.num_workers,
Expand Down Expand Up @@ -460,7 +460,8 @@ def train_and_eval_with_cfg(
print(f"Train Loss: {train_loss:.4f}")
if epoch % 5 == 0:
# Check if we should early stop
test_pred = test(model, loader_dict["test"], "test", "test")
test_pred = test(model, loader_dict["test"], "test", "test",
target=target_list)
test_metric = calculate_hit_rate(test_pred, target_list)
print(f"Test metric: {test_metric:.4f}")
if test_metric > best_test_metric:
Expand Down

0 comments on commit 4ede112

Please sign in to comment.