From d91e36509ddcf1270ad612dbeba29d69ad81f543 Mon Sep 17 00:00:00 2001 From: Yiwen Yuan Date: Tue, 19 Nov 2024 06:40:06 +0000 Subject: [PATCH] incorporate rhs sample size --- benchmark/efficiency.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/benchmark/efficiency.py b/benchmark/efficiency.py index acbacf1..2726c61 100644 --- a/benchmark/efficiency.py +++ b/benchmark/efficiency.py @@ -37,7 +37,6 @@ warnings.filterwarnings("ignore", category=FutureWarning) - parser = argparse.ArgumentParser() parser.add_argument("--dataset", type=str, default="rel-trial") parser.add_argument("--task", type=str, default="site-sponsor-run") @@ -58,7 +57,6 @@ default=os.path.expanduser("~/.cache/relbench_examples")) args = parser.parse_args() - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if torch.cuda.is_available(): torch.set_num_threads(1) @@ -202,7 +200,8 @@ def train() -> float: ).float() loss = F.binary_cross_entropy_with_logits(out, target) elif model_type in ['contextgnn', 'shallowrhsgnn']: - logits = model(batch, task.src_entity_table, task.dst_entity_table) + logits = model(batch, task.src_entity_table, + task.dst_entity_table) edge_label_index = torch.stack([src_batch, dst_index], dim=0) loss = sparse_cross_entropy(logits, edge_label_index) @@ -234,7 +233,8 @@ def train() -> float: ).float() loss = F.binary_cross_entropy_with_logits(out, target) elif model_type in ['contextgnn', 'shallowrhsgnn']: - logits = model(batch, task.src_entity_table, task.dst_entity_table) + logits = model(batch, task.src_entity_table, + task.dst_entity_table) edge_label_index = torch.stack([src_batch, dst_index], dim=0) loss = sparse_cross_entropy(logits, edge_label_index) @@ -250,8 +250,7 @@ def train() -> float: gpu_time = start.elapsed_time(end) gpu_time_in_s = gpu_time / 1_000 print( - f"model: {model_type}, ", - f"total: {gpu_time_in_s} s, " + f"model: {model_type}, ", f"total: {gpu_time_in_s} s, " f"avg: {gpu_time_in_s / num_steps} s/iter, " f"avg: {num_steps / gpu_time_in_s} iter/s")