Skip to content

Commit

Permalink
fix efficiency
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiwen Yuan committed Nov 19, 2024
1 parent d91e365 commit 3fd460e
Showing 1 changed file with 30 additions and 10 deletions.
40 changes: 30 additions & 10 deletions benchmark/efficiency.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--eval_epochs_interval", type=int, default=1)
parser.add_argument("--rhs_sample_size", type=int, default=10)
parser.add_argument("--rhs_sample_size", type=int, default=-1)
parser.add_argument("--batch_size", type=int, default=512)
parser.add_argument("--channels", type=int, default=128)
parser.add_argument("--aggr", type=str, default="sum")
Expand Down Expand Up @@ -141,7 +141,7 @@ def create_model(model_type: str) -> Union[IDGNN, ContextGNN, ShallowRHSGNN]:
return ContextGNN(
data=data,
col_stats_dict=col_stats_dict,
rhs_emb_mode=RHSEmbeddingMode.FUSION,
rhs_emb_mode=RHSEmbeddingMode.LOOKUP,
dst_entity_table=task.dst_entity_table,
num_nodes=num_dst_nodes_dict["train"],
num_layers=args.num_layers,
Expand All @@ -153,6 +153,8 @@ def create_model(model_type: str) -> Union[IDGNN, ContextGNN, ShallowRHSGNN]:
"channels": 64,
"num_layers": 4,
},
rhs_sample_size=None
if args.rhs_sample_size < 0 else args.rhs_sample_size,
).to(device)
elif model_type == 'shallowrhsgnn':
return ShallowRHSGNN(
Expand All @@ -179,7 +181,7 @@ def create_model(model_type: str) -> Union[IDGNN, ContextGNN, ShallowRHSGNN]:
model = create_model(model_type)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

def train() -> float:
def train():
model.train()

print("warming up...")
Expand All @@ -200,9 +202,18 @@ def train() -> float:
).float()
loss = F.binary_cross_entropy_with_logits(out, target)
elif model_type in ['contextgnn', 'shallowrhsgnn']:
logits = model(batch, task.src_entity_table,
task.dst_entity_table)
edge_label_index = torch.stack([src_batch, dst_index], dim=0)
if args.rhs_sample_size < 0:
logits = model(batch, task.src_entity_table,
task.dst_entity_table)
edge_label_index = torch.stack([src_batch, dst_index],
dim=0)
else:
(logits, lhs_y_batch,
rhs_y_index) = model.forward_sample_softmax(
batch, task.src_entity_table, task.dst_entity_table,
src_batch, dst_index)
edge_label_index = torch.stack([lhs_y_batch, rhs_y_index],
dim=0)
loss = sparse_cross_entropy(logits, edge_label_index)

loss.backward()
Expand Down Expand Up @@ -233,9 +244,18 @@ def train() -> float:
).float()
loss = F.binary_cross_entropy_with_logits(out, target)
elif model_type in ['contextgnn', 'shallowrhsgnn']:
logits = model(batch, task.src_entity_table,
task.dst_entity_table)
edge_label_index = torch.stack([src_batch, dst_index], dim=0)
if args.rhs_sample_size < 0:
logits = model(batch, task.src_entity_table,
task.dst_entity_table)
edge_label_index = torch.stack([src_batch, dst_index],
dim=0)
else:
(logits, lhs_y_batch,
rhs_y_index) = model.forward_sample_softmax(
batch, task.src_entity_table, task.dst_entity_table,
src_batch, dst_index)
edge_label_index = torch.stack([lhs_y_batch, rhs_y_index],
dim=0)
loss = sparse_cross_entropy(logits, edge_label_index)

loss.backward()
Expand All @@ -245,7 +265,7 @@ def train() -> float:
print(f"done at {i}th step")
break

end.record() # type: ignore
end.record()
torch.cuda.synchronize()
gpu_time = start.elapsed_time(end)
gpu_time_in_s = gpu_time / 1_000
Expand Down

0 comments on commit 3fd460e

Please sign in to comment.