Commit

wip
Yiwen Yuan committed Nov 17, 2024
1 parent 8c1cd51 commit 7aebd8a
Showing 2 changed files with 73 additions and 34 deletions.
100 changes: 70 additions & 30 deletions contextgnn/nn/models/contextgnn.py
@@ -120,13 +120,40 @@ def sample_step(self, rhs_idgnn_index, lhs_idgnn_batch, rhs_gnn_embedding,
        rhs_gnn_embedding = rhs_gnn_embedding[mask]
        return rhs_embedding, lhs_idgnn_batch, rhs_gnn_embedding, rhs_y_index

    def common(self, lhs_embedding_projected, rhs_embedding, rhs_gnn_embedding,
               lhs_embedding, lhs_idgnn_batch, rhs_idgnn_index):
        # batch_size, num_rhs_nodes
        embgnn_logits = lhs_embedding_projected @ rhs_embedding.t()

        # Model the importance of the embedding-GNN prediction for each
        # lhs node
        embgnn_offset_logits = self.lin_offset_embgnn(
            lhs_embedding_projected).flatten()
        embgnn_logits += embgnn_offset_logits.view(-1, 1)

        # Calculate the ID-GNN logits
        idgnn_logits = self.head(rhs_gnn_embedding).flatten()  # num_sampled_rhs
        # Because we only sample 2 hops, the sampled rhs embeddings carry no
        # information about the lhs seed nodes, so we inject it explicitly
        # via lhs_embedding[lhs_idgnn_batch] * rhs_gnn_embedding:
        idgnn_logits += (
            lhs_embedding[lhs_idgnn_batch] *  # num_sampled_rhs, channel
            rhs_gnn_embedding).sum(dim=-1).flatten()  # num_sampled_rhs

        # Model the importance of the ID-GNN prediction for each lhs node
        idgnn_offset_logits = self.lin_offset_idgnn(
            lhs_embedding_projected).flatten()
        idgnn_logits = idgnn_logits + idgnn_offset_logits[lhs_idgnn_batch]

        # ID-GNN logits override the embedding-GNN scores at the sampled
        # (lhs, rhs) positions.
        embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits
        return embgnn_logits

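For intuition, here is a minimal, self-contained toy sketch (illustrative shapes and values, not part of this commit) of the fusion that common() performs: dense embedding-GNN scores for every rhs node, with ID-GNN scores scattered into the (batch row, rhs node) positions the sampler actually reached.

import torch

batch_size, num_rhs_nodes, channels = 2, 5, 4
lhs = torch.randn(batch_size, channels)     # projected lhs embeddings
rhs = torch.randn(num_rhs_nodes, channels)  # shallow rhs embedding table

embgnn_logits = lhs @ rhs.t()  # [batch_size, num_rhs_nodes]

# Three (lhs row, rhs node) pairs reached by the ID-GNN.
lhs_idgnn_batch = torch.tensor([0, 0, 1])
rhs_idgnn_index = torch.tensor([1, 3, 2])
idgnn_logits = torch.randn(3)  # one logit per sampled pair

# Wherever an ID-GNN logit exists, it overrides the embedding-GNN score.
embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits
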
    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
        src_batch: Optional[Tensor] = None,
        dst_index: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)
@@ -153,39 +180,52 @@ def forward(
        rhs_idgnn_index = batch.n_id_dict[dst_table]  # num_sampled_rhs
        lhs_idgnn_batch = batch.batch_dict[dst_table]  # num_sampled_rhs

        rhs_embedding = self.rhs_embedding()  # num_rhs_nodes, channel
        embgnn_logits = self.common(lhs_embedding_projected, rhs_embedding,
                                    rhs_gnn_embedding, lhs_embedding,
                                    lhs_idgnn_batch, rhs_idgnn_index)
        return embgnn_logits

    def forward_sample_softmax(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
        src_batch: Optional[Tensor] = None,
        dst_index: Optional[Tensor] = None,
    ):
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)

        # Add ID-awareness to the root node
        x_dict[entity_table][:seed_time.size(0)] += self.id_awareness_emb.weight
        rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
                                              batch.batch_dict)

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        batch_size = seed_time.size(0)
        lhs_embedding = x_dict[entity_table][:batch_size]  # batch_size, channel
        lhs_embedding_projected = self.lhs_projector(lhs_embedding)
        rhs_gnn_embedding = x_dict[dst_table]  # num_sampled_rhs, channel
        rhs_idgnn_index = batch.n_id_dict[dst_table]  # num_sampled_rhs
        lhs_idgnn_batch = batch.batch_dict[dst_table]  # num_sampled_rhs

        (rhs_embedding, lhs_idgnn_batch, rhs_gnn_embedding,
         rhs_y_index) = self.sample_step(rhs_idgnn_index, lhs_idgnn_batch,
                                         rhs_gnn_embedding, src_batch,
                                         dst_index)
        embgnn_logits = self.common(lhs_embedding_projected, rhs_embedding,
                                    rhs_gnn_embedding, lhs_embedding,
                                    lhs_idgnn_batch, rhs_idgnn_index)
        return embgnn_logits, src_batch, rhs_y_index

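Why sample at all? The softmax then runs over a fixed rhs_sample_size candidate set instead of all num_rhs_nodes, while rhs_y_index records which column each ground-truth destination landed in. A purely illustrative sketch of that idea (hypothetical values; the commit's sample_step is not shown in full here):

import torch

num_rhs_nodes, rhs_sample_size = 10_000, 128
dst_index = torch.tensor([42, 7, 999])  # ground-truth rhs node per seed

# Keep the positives and pad with uniformly drawn negatives so each
# ground-truth destination owns a column in the sampled logit matrix.
neg = torch.randint(num_rhs_nodes, (rhs_sample_size - dst_index.numel(), ))
candidates = torch.cat([dst_index, neg])  # [rhs_sample_size]
rhs_y_index = torch.arange(dst_index.numel())  # positives sit up front
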
    def to(self, *args, **kwargs) -> Self:
7 changes: 3 additions & 4 deletions examples/sample_softmax.py
@@ -199,10 +199,9 @@ def train() -> float:
            loss = F.binary_cross_entropy_with_logits(out, target)
            numel = out.numel()
        elif args.model in ['contextgnn', 'shallowrhsgnn']:
            logits, lhs_y_batch, rhs_y_index = model.forward_sample_softmax(
                batch, task.src_entity_table, task.dst_entity_table,
                src_batch, dst_index)
            edge_label_index = torch.stack([lhs_y_batch, rhs_y_index], dim=0)
            loss = sparse_cross_entropy(logits, edge_label_index)
            numel = len(batch[task.dst_entity_table].batch)
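sparse_cross_entropy comes from the ContextGNN code base and is not shown in this diff; a plausible minimal definition (an assumption, not the repository's exact implementation) treats edge_label_index as the (row, column) positions of the positives in the logit matrix:

import torch
from torch import Tensor

def sparse_cross_entropy(logits: Tensor, edge_label_index: Tensor) -> Tensor:
    # logits: [batch_size, num_candidates]; edge_label_index: [2, num_pos]
    pos = logits[edge_label_index[0], edge_label_index[1]]
    log_z = torch.logsumexp(logits, dim=-1)[edge_label_index[0]]
    return (log_z - pos).mean()  # softmax cross-entropy, one term per positive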
