Skip to content

Commit

Permalink
incorporate rhs sample size
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiwen Yuan committed Nov 19, 2024
1 parent 8af2cab commit d91e365
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions benchmark/efficiency.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@

warnings.filterwarnings("ignore", category=FutureWarning)


parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="rel-trial")
parser.add_argument("--task", type=str, default="site-sponsor-run")
Expand All @@ -58,7 +57,6 @@
default=os.path.expanduser("~/.cache/relbench_examples"))
args = parser.parse_args()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
torch.set_num_threads(1)
Expand Down Expand Up @@ -202,7 +200,8 @@ def train() -> float:
).float()
loss = F.binary_cross_entropy_with_logits(out, target)
elif model_type in ['contextgnn', 'shallowrhsgnn']:
logits = model(batch, task.src_entity_table, task.dst_entity_table)
logits = model(batch, task.src_entity_table,
task.dst_entity_table)
edge_label_index = torch.stack([src_batch, dst_index], dim=0)
loss = sparse_cross_entropy(logits, edge_label_index)

Expand Down Expand Up @@ -234,7 +233,8 @@ def train() -> float:
).float()
loss = F.binary_cross_entropy_with_logits(out, target)
elif model_type in ['contextgnn', 'shallowrhsgnn']:
logits = model(batch, task.src_entity_table, task.dst_entity_table)
logits = model(batch, task.src_entity_table,
task.dst_entity_table)
edge_label_index = torch.stack([src_batch, dst_index], dim=0)
loss = sparse_cross_entropy(logits, edge_label_index)

Expand All @@ -250,8 +250,7 @@ def train() -> float:
gpu_time = start.elapsed_time(end)
gpu_time_in_s = gpu_time / 1_000
print(
f"model: {model_type}, ",
f"total: {gpu_time_in_s} s, "
f"model: {model_type}, ", f"total: {gpu_time_in_s} s, "
f"avg: {gpu_time_in_s / num_steps} s/iter, "
f"avg: {num_steps / gpu_time_in_s} iter/s")

Expand Down

0 comments on commit d91e365

Please sign in to comment.