From 7568500121822321b844d77217fc47afab90299d Mon Sep 17 00:00:00 2001
From: Yiwen Yuan
Date: Wed, 2 Oct 2024 05:07:37 +0000
Subject: [PATCH] add calculation for test

---
 scripts/subgraph.py | 38 +++++++++++++++++++++++++++++++++++---
 1 file changed, 35 insertions(+), 3 deletions(-)

diff --git a/scripts/subgraph.py b/scripts/subgraph.py
index 0fdf64e..21c0f4a 100644
--- a/scripts/subgraph.py
+++ b/scripts/subgraph.py
@@ -25,12 +25,12 @@ from torch_geometric.typing import NodeType
 
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--dataset", type=str, default="rel-hm")
-parser.add_argument("--task", type=str, default="user-item-purchase")
+parser.add_argument("--dataset", type=str, default="rel-amazon")
+parser.add_argument("--task", type=str, default="user-item-review")
 
 parser.add_argument("--epochs", type=int, default=20)
 parser.add_argument("--eval_epochs_interval", type=int, default=1)
-parser.add_argument("--batch_size", type=int, default=64)
+parser.add_argument("--batch_size", type=int, default=128)
 parser.add_argument("--channels", type=int, default=512)
 parser.add_argument("--aggr", type=str, default="sum")
 parser.add_argument("--num_layers", type=int, default=2)
@@ -113,6 +113,7 @@ test_sparse_tensor = SparseTensor(dst_nodes_dict["test"][1], device=device)
 
 
 score = num_examples = 0
+
 for batch in loader_dict["val"]:
     # batch.to(device)
 
@@ -142,6 +143,37 @@ print(args.dataset, args.task, args.num_layers)
 print(score / num_examples)
 
 
+
+score = num_examples = 0
+for batch in loader_dict["test"]:
+    # batch.to(device)
+
+    rhs = batch[task.dst_entity_table].n_id
+    rhs_batch = batch[task.dst_entity_table].batch
+    batch_size = batch[task.src_entity_table].batch_size
+
+    input_id = batch[task.src_entity_table].input_id
+    # Obtain the ground-truth destinations for the test split
+    test_src_batch, test_dst_index = test_sparse_tensor[input_id]
+
+    # Map (example, node) pairs to unique 1-D indices
+    rhs = rhs_batch * num_rhs_nodes + rhs
+    ground_truth_rhs = test_src_batch * num_rhs_nodes + test_dst_index
+
+    seen = torch.isin(ground_truth_rhs, rhs).long()
+
+    from torch_scatter import scatter_add
+    ground_truth_count = scatter_add(torch.ones_like(test_src_batch),
+                                     test_src_batch, dim_size=batch_size)
+
+    seen = scatter_add(seen.long(), test_src_batch, dim_size=batch_size)
+
+    score += (seen / ground_truth_count).sum().item()
+    num_examples += batch_size
+
+print(args.dataset, args.task, args.num_layers)
+print(score / num_examples)
+
 # test_table = task.get_table('test')
 # test_df = test_table.df
 # test_seen_percent = []
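
Note: below is a minimal, self-contained sketch (toy tensors and made-up sizes, not the repository's loaders or task objects) of the coverage metric the new test-split block computes: for each source node in a mini-batch, the fraction of its ground-truth destination nodes that were actually sampled into its subgraph. It relies on the same pair-encoding trick (batch index * num_rhs_nodes + node id) and the same torch.isin / torch_scatter.scatter_add calls as the patch.

import torch
from torch_scatter import scatter_add

# Toy sizes, assumptions for illustration only.
num_rhs_nodes = 10   # total number of destination ("rhs") nodes
batch_size = 2       # number of source nodes in this mini-batch

# Destination nodes sampled into each source node's subgraph,
# given as (example index, destination node id) pairs.
rhs_batch = torch.tensor([0, 0, 1, 1, 1])
rhs = torch.tensor([3, 7, 1, 2, 9])

# Ground-truth destinations per example, in the same layout.
gt_src_batch = torch.tensor([0, 0, 0, 1, 1])
gt_dst_index = torch.tensor([3, 4, 7, 2, 5])

# Encode (example, node) pairs as single integers so torch.isin can
# compare pairs directly: pair_id = example * num_rhs_nodes + node.
sampled_pairs = rhs_batch * num_rhs_nodes + rhs
gt_pairs = gt_src_batch * num_rhs_nodes + gt_dst_index

# 1 where a ground-truth destination was sampled for its example, else 0.
hit = torch.isin(gt_pairs, sampled_pairs).long()

# Per-example number of ground-truth destinations and number of hits.
gt_count = scatter_add(torch.ones_like(gt_src_batch), gt_src_batch,
                       dim_size=batch_size)
hit_count = scatter_add(hit, gt_src_batch, dim_size=batch_size)

coverage = hit_count / gt_count
print(coverage)         # tensor([0.6667, 0.5000])
print(coverage.mean())  # average fraction of ground truth seen per example

In the patch itself, rhs/rhs_batch come from the sampled PyG mini-batch and the ground truth comes from test_sparse_tensor[input_id]; the toy tensors above merely stand in for those, and score / num_examples printed at the end corresponds to coverage.mean() averaged over all test batches.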