Skip to content

Commit

Permalink
add calculation for test
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiwen Yuan committed Oct 2, 2024
1 parent 598a5a3 commit 7568500
Showing 1 changed file with 35 additions and 3 deletions.
38 changes: 35 additions & 3 deletions scripts/subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
from torch_geometric.typing import NodeType

# Command-line configuration for the script.
# Each option is registered exactly once: argparse raises ArgumentError
# ("conflicting option string") when the same flag is added twice.
parser = argparse.ArgumentParser()

# Which dataset / task pair to evaluate.
parser.add_argument("--dataset", type=str, default="rel-amazon")
parser.add_argument("--task", type=str, default="user-item-review")

# Loop and loader settings.
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--eval_epochs_interval", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=128)

# Model-shape settings (presumably consumed further down the script —
# TODO confirm against the part of the file not shown here).
parser.add_argument("--channels", type=int, default=512)
parser.add_argument("--aggr", type=str, default="sum")
parser.add_argument("--num_layers", type=int, default=2)
Expand Down Expand Up @@ -113,6 +113,7 @@
test_sparse_tensor = SparseTensor(dst_nodes_dict["test"][1], device=device)

score = num_examples = 0

for batch in loader_dict["val"]:
# batch.to(device)

Expand Down Expand Up @@ -142,6 +143,37 @@
print(args.dataset, args.task, args.num_layers)
print(score / num_examples)


# Test-split evaluation: for every source node in the test loader, compute
# the fraction of its ground-truth destination nodes that are present among
# the destination nodes of the loaded batch, then print the average of that
# fraction over all evaluated source nodes.
from torch_scatter import scatter_add  # imported once, not on every batch

score = num_examples = 0
for batch in loader_dict["test"]:
    # NOTE(review): the batch is not moved to `device` here (the original
    # `batch.to(device)` call was commented out) — confirm the loader yields
    # tensors on the same device as `test_sparse_tensor`.
    dst_store = batch[task.dst_entity_table]
    src_store = batch[task.src_entity_table]

    rhs = dst_store.n_id
    rhs_batch = dst_store.batch
    batch_size = src_store.batch_size
    input_id = src_store.input_id

    # Ground-truth pairs of the test split for the source nodes of this
    # batch: (position of the source node in the batch, destination index).
    test_src_batch, test_dst_index = test_sparse_tensor[input_id]

    # Map every (batch position, node index) pair to a single integer key so
    # that membership can be tested with one 1-d `isin` call.
    rhs = rhs_batch * num_rhs_nodes + rhs
    ground_truth_rhs = test_src_batch * num_rhs_nodes + test_dst_index

    # 1 where a ground-truth pair is among the batch's destination nodes.
    seen = torch.isin(ground_truth_rhs, rhs).long()

    # Per source node: number of ground-truth pairs, and how many were seen.
    ground_truth_count = scatter_add(
        torch.ones_like(test_src_batch),
        test_src_batch,
        dim_size=batch_size,
    )
    seen_count = scatter_add(seen, test_src_batch, dim_size=batch_size)

    # A source node without any ground-truth pair has 0 seen / 0 total;
    # clamping the denominator makes it contribute 0 instead of NaN, which
    # would otherwise poison the accumulated score.
    score += (seen_count / ground_truth_count.clamp(min=1)).sum().item()
    num_examples += batch_size

print(args.dataset, args.task, args.num_layers)
print(score / num_examples)

# test_table = task.get_table('test')
# test_df = test_table.df
# test_seen_percent = []
Expand Down

0 comments on commit 7568500

Please sign in to comment.