
Merge pull request #43 from kumo-ai/yyuan/implement-sample-softmax
Sample softmax for ContextGNN
yiweny authored Nov 18, 2024
2 parents 9e3f913 + 2b51dc2 commit f62a240
Showing 5 changed files with 403 additions and 34 deletions.
9 changes: 6 additions & 3 deletions README.md
@@ -1,7 +1,10 @@
## How to Run

We run our experiments on NVIDIA L40S Tensor Core GPU with 44.7 GB of memory.
If you want to run with less GPU memory, set `num_layers=2` in all the scripts, unless you are running `rel-trial` from RelBench, in which case use `num_layers=4`.
If you want to run ContextGNN on a GPU with less memory, run `examples/contextgnn_sample_softmax.py`:

```sh
python examples/contextgnn_sample_softmax.py --rhs_sample_size 1000
```

To reproduce results on RelBench, run `benchmark/relbench_link_prediction_benchmark.py`.

@@ -15,7 +18,7 @@ To reproduce results on IJCAI-Contest, run `benchmark/tgt_ijcai_benchmark.py`.
python tgt_ijcai_benchmark.py --model contextgnn
```

To run ContextGNN without optuna tuning, run
To run ContextGNN normally, run

```sh
python relbench_example.py --dataset rel-trial --task site-sponsor-run --model contextgnn
135 changes: 107 additions & 28 deletions contextgnn/nn/models/contextgnn.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, Optional, Tuple, Type

import torch
from torch import Tensor
@@ -7,6 +7,7 @@
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType
from torch_geometric.utils.map import map_index
from typing_extensions import Self

from contextgnn.nn.encoder import (
@@ -35,6 +36,7 @@ def __init__(
norm: str = 'layer_norm',
torch_frame_model_cls: Type[torch.nn.Module] = ResNet,
torch_frame_model_kwargs: Optional[Dict[str, Any]] = None,
rhs_sample_size: Optional[int] = None,
) -> None:
super().__init__(data, col_stats_dict, rhs_emb_mode, dst_entity_table,
num_nodes, embedding_dim)
@@ -76,6 +78,8 @@ def __init__(
self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1)
self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1)
self.channels = channels
self.num_rhs_nodes = num_nodes
self.rhs_sample_size = rhs_sample_size

self.reset_parameters()

@@ -91,12 +95,67 @@ def reset_parameters(self) -> None:
self.lin_offset_idgnn.reset_parameters()
self.lhs_projector.reset_parameters()

def forward(
    def sample_step(self, rhs_idgnn_index, lhs_idgnn_batch, rhs_gnn_embedding,
                    lhs_y_batch, rhs_y_index):
        rnd = torch.rand(self.num_rhs_nodes, device=rhs_idgnn_index.device)
        # Prioritize RHS nodes reached by the ID-GNN over random negatives
        rnd[rhs_idgnn_index] = 3.
        # Ensure we always sample the ground-truth positives
        assert rhs_y_index is not None  # the dst index must always be passed in
        rnd[rhs_y_index] = 4.
        rhs_index = rnd.topk(self.rhs_sample_size, sorted=True).indices
        inclusive = rhs_y_index.numel() <= self.rhs_sample_size
        rhs_y_index, mask = map_index(rhs_y_index, rhs_index,
                                      max_index=self.num_rhs_nodes,
                                      inclusive=inclusive)
        lhs_y_batch = lhs_y_batch if inclusive else lhs_y_batch[mask]
        rhs_embedding = self.rhs_embedding(rhs_index)  # rhs_sample_size, channel
        inclusive = (rhs_y_index.numel() + rhs_idgnn_index.numel()
                     <= self.rhs_sample_size)
        rhs_idgnn_index, mask = map_index(rhs_idgnn_index, rhs_index,
                                          inclusive=inclusive)
        if not inclusive:
            lhs_idgnn_batch = lhs_idgnn_batch[mask]
            rhs_gnn_embedding = rhs_gnn_embedding[mask]
        return (rhs_idgnn_index, rhs_embedding, lhs_idgnn_batch,
                rhs_gnn_embedding, lhs_y_batch, rhs_y_index)
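
The sampling trick above draws a uniform random score for every RHS node and then boosts the scores of ID-GNN candidates (to 3) and ground-truth positives (to 4), so `topk` keeps all positives first, then all ID-GNN candidates, and fills the remaining slots with uniformly random negatives. A minimal self-contained sketch of the idea, with hypothetical sizes:

```python
import torch

num_rhs_nodes, sample_size = 20, 8      # hypothetical sizes
idgnn_index = torch.tensor([2, 5, 11])  # RHS nodes reached by the ID-GNN
positive_index = torch.tensor([5, 17])  # ground-truth destination nodes

rnd = torch.rand(num_rhs_nodes)  # uniform scores in [0, 1)
rnd[idgnn_index] = 3.            # lift ID-GNN candidates above all negatives
rnd[positive_index] = 4.         # lift positives above everything else

sampled = rnd.topk(sample_size, sorted=True).indices
# `sampled` now holds every positive, then every ID-GNN candidate, then
# uniformly random negatives in the remaining slots.
```

`map_index` then converts the global node indices into positions within the sampled set; with `inclusive=False` it also returns a mask marking which inputs were found, which `sample_step` uses to drop ID-GNN pairs that did not survive sampling.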

    def construct_logits(self, lhs_embedding_projected, lhs_embedding,
                         rhs_gnn_embedding, rhs_embedding, lhs_idgnn_batch,
                         rhs_idgnn_index):
        # batch_size, num_rhs_nodes (or rhs_sample_size when sampling)
        embgnn_logits = lhs_embedding_projected @ rhs_embedding.t()

        # Model the importance of the embedding-GNN prediction per lhs node
        embgnn_offset_logits = self.lin_offset_embgnn(
            lhs_embedding_projected).flatten()
        embgnn_logits += embgnn_offset_logits.view(-1, 1)

        # Calculate ID-GNN logits
        idgnn_logits = self.head(
            rhs_gnn_embedding).flatten()  # num_sampled_rhs
        # Because we only sample two hops, little information flows from the
        # lhs into the rhs embeddings, so we inject it explicitly via
        # lhs_embedding[lhs_idgnn_batch] * rhs_gnn_embedding
        idgnn_logits += (
            lhs_embedding[lhs_idgnn_batch] *  # num_sampled_rhs, channel
            rhs_gnn_embedding).sum(dim=-1).flatten()  # num_sampled_rhs

        # Model the importance of the ID-GNN prediction per lhs node
        idgnn_offset_logits = self.lin_offset_idgnn(
            lhs_embedding_projected).flatten()
        idgnn_logits = idgnn_logits + idgnn_offset_logits[lhs_idgnn_batch]

        # Overwrite embedding-GNN scores where ID-GNN evidence exists
        embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits
        return embgnn_logits
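
In effect, `construct_logits` starts from dense embedding-GNN scores for every (seed, RHS) pair and then overwrites exactly the entries the ID-GNN reached with the ID-GNN scores. A toy sketch of that fusion, using made-up shapes:

```python
import torch

batch_size, num_rhs, channels = 2, 6, 4  # made-up shapes
lhs = torch.randn(batch_size, channels)
rhs = torch.randn(num_rhs, channels)

# Dense embedding-GNN scores for every (seed, RHS) pair:
logits = lhs @ rhs.t()  # (batch_size, num_rhs)

# Sparse ID-GNN scores: node 1 reached from seed 0, nodes 3 and 4 from seed 1:
lhs_batch = torch.tensor([0, 1, 1])
rhs_index = torch.tensor([1, 3, 4])
idgnn_scores = torch.randn(3)

# Overwrite the dense scores wherever ID-GNN evidence exists:
logits[lhs_batch, rhs_index] = idgnn_scores
```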

def forward_gnn(
self,
batch: HeteroData,
entity_table: NodeType,
dst_table: NodeType,
) -> Tensor:
):
seed_time = batch[entity_table].seed_time
x_dict = self.encoder(batch.tf_dict)

@@ -113,6 +172,16 @@ def forward(
x_dict,
batch.edge_index_dict,
)
return x_dict

def forward(
self,
batch: HeteroData,
entity_table: NodeType,
dst_table: NodeType,
) -> Tensor:
seed_time = batch[entity_table].seed_time
        x_dict = self.forward_gnn(batch, entity_table, dst_table)

batch_size = seed_time.size(0)
lhs_embedding = x_dict[entity_table][:
@@ -121,34 +190,44 @@ def forward(
rhs_gnn_embedding = x_dict[dst_table] # num_sampled_rhs, channel
rhs_idgnn_index = batch.n_id_dict[dst_table] # num_sampled_rhs
lhs_idgnn_batch = batch.batch_dict[dst_table] # batch_size
rhs_embedding = self.rhs_embedding() # num_rhs_nodes, channel

        embgnn_logits = lhs_embedding_projected @ rhs_embedding.t()  # batch_size, num_rhs_nodes

# Model the importance of embedding-GNN prediction for each lhs node
embgnn_offset_logits = self.lin_offset_embgnn(
lhs_embedding_projected).flatten()
embgnn_logits += embgnn_offset_logits.view(-1, 1)
rhs_embedding = self.rhs_embedding() # num_rhs_nodes, channel
embgnn_logits = self.construct_logits(lhs_embedding_projected,
lhs_embedding, rhs_gnn_embedding,
rhs_embedding, lhs_idgnn_batch,
rhs_idgnn_index)
return embgnn_logits

        # Calculate ID-GNN logits
        idgnn_logits = self.head(
            rhs_gnn_embedding).flatten()  # num_sampled_rhs
        # Because we only sample two hops, little information flows from the
        # lhs into the rhs embeddings, so we inject it explicitly via
        # lhs_embedding[lhs_idgnn_batch] * rhs_gnn_embedding
        idgnn_logits += (
            lhs_embedding[lhs_idgnn_batch] *  # num_sampled_rhs, channel
            rhs_gnn_embedding).sum(dim=-1).flatten()  # num_sampled_rhs
def forward_sample_softmax(
self,
batch: HeteroData,
entity_table: NodeType,
dst_table: NodeType,
src_batch: Optional[Tensor] = None,
dst_index: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
r"""Forward function with RHS sample softmax."""
seed_time = batch[entity_table].seed_time
x_dict = self.forward_gnn(batch, entity_table)

# Model the importance of ID-GNN prediction for each lhs node
idgnn_offset_logits = self.lin_offset_idgnn(
lhs_embedding_projected).flatten()
idgnn_logits = idgnn_logits + idgnn_offset_logits[lhs_idgnn_batch]
batch_size = seed_time.size(0)
lhs_embedding = x_dict[entity_table][:
batch_size] # batch_size, channel
lhs_embedding_projected = self.lhs_projector(lhs_embedding)
rhs_gnn_embedding = x_dict[dst_table] # num_sampled_rhs, channel
rhs_idgnn_index = batch.n_id_dict[dst_table] # num_sampled_rhs
lhs_idgnn_batch = batch.batch_dict[dst_table] # batch_size

embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits
return embgnn_logits
(rhs_idgnn_index, rhs_embedding, lhs_idgnn_batch, rhs_gnn_embedding,
lhs_y_batch, rhs_y_index) = self.sample_step(rhs_idgnn_index,
lhs_idgnn_batch,
rhs_gnn_embedding,
src_batch, dst_index)
embgnn_logits = self.construct_logits(lhs_embedding_projected,
lhs_embedding, rhs_gnn_embedding,
rhs_embedding, lhs_idgnn_batch,
rhs_idgnn_index)
return embgnn_logits, lhs_y_batch, rhs_y_index
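
The three returned tensors line up so that `embgnn_logits[lhs_y_batch[i], rhs_y_index[i]]` is the logit of the i-th positive pair within the sampled candidate set. One plausible way a training script could consume them as a sampled-softmax loss (a sketch under assumed names; `model`, `batch`, `src_batch`, and `dst_index` are placeholders, not the repository's actual training loop):

```python
import torch.nn.functional as F

logits, lhs_y_batch, rhs_y_index = model.forward_sample_softmax(
    batch, entity_table, dst_table, src_batch=src_batch, dst_index=dst_index)
# One softmax per positive pair over the rhs_sample_size candidates, with
# the positive's position in the sample as the target class.
loss = F.cross_entropy(logits[lhs_y_batch], rhs_y_index)
```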

def to(self, *args, **kwargs) -> Self:
return super().to(*args, **kwargs)
12 changes: 9 additions & 3 deletions contextgnn/nn/rhs_embedding.py
@@ -66,17 +66,23 @@ def reset_parameters(self) -> None:
child.reset_parameters()
self._cached_rhs_embedding = None

def forward(self) -> Tensor:
def forward(self, index: Optional[Tensor] = None) -> Tensor:
if not self.training:
if self._cached_rhs_embedding is not None:
return self._cached_rhs_embedding
outs = []
if self.lookup_embedding is not None:
outs.append(self.lookup_embedding.weight)
if index is None:
outs.append(self.lookup_embedding.weight)
else:
outs.append(self.lookup_embedding.weight[index, :])
if self.encoder is not None and self.projector is not None:
assert self._feat is not None

out = self.encoder(self._feat)[0]
if index is None:
out = self.encoder(self._feat)[0]
else:
out = self.encoder(self._feat[index])[0]
out = self.projector(out)
            # fuse the per-column embeddings into a single vector per node
out = torch.sum(out, dim=1)
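
Accepting an optional `index` lets callers materialize embeddings for only the sampled RHS nodes rather than all of them, which is what makes the sampled softmax cheap. A rough sketch of the difference, under assumed sizes:

```python
import torch

num_rhs_nodes, channels, sample_size = 1_000_000, 128, 1_000  # assumed sizes
emb = torch.nn.Embedding(num_rhs_nodes, channels)
sampled = torch.randint(num_rhs_nodes, (sample_size,))

full = emb.weight            # (1_000_000, 128): every row feeds the logit matmul
small = emb.weight[sampled]  # (1_000, 128): only sampled rows are processed
```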
