Commit a5984cd
wip
Yiwen Yuan committed Oct 24, 2024
1 parent 64029ef commit a5984cd
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions examples/ijcai_example.py
@@ -32,7 +32,7 @@
     choices=["contextgnn", "idgnn", "shallowrhsgnn"],
 )
 parser.add_argument("--lr", type=float, default=0.001)
-parser.add_argument("--epochs", type=int, default=10)
+parser.add_argument("--epochs", type=int, default=100)
 parser.add_argument("--eval_epochs_interval", type=int, default=1)
 parser.add_argument("--batch_size", type=int, default=512)
 parser.add_argument("--channels", type=int, default=128)
@@ -137,6 +137,7 @@ def create_edge(data, behavior, beh_idx, pkey_name, pkey_idx):
 
 col_stats_dict = {}
 dst_nodes = None
+trnLabel = None
 for i in range(len(behs)):
     behavior = behs[i]
     with open(osp.join(path, 'trn_' + behavior), 'rb') as fs:
@@ -166,6 +167,7 @@ def create_edge(data, behavior, beh_idx, pkey_name, pkey_idx):
     coo_mat = sp.coo_matrix(mat)
     if behavior == 'buy':
         dst_nodes = coo_mat
+        trnLabel = 1 * (mat != 0)
     beh_idx = torch.arange(len(coo_mat.data), dtype=torch.long)
     create_edge(data, behavior, beh_idx, src_entity_table,
                 torch.tensor(coo_mat.row, dtype=torch.long))
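For context, a minimal sketch of what the new trnLabel binarization produces, assuming mat is a scipy.sparse matrix of raw interaction values (the toy data below is illustrative, not part of the commit):

import numpy as np
import scipy.sparse as sp

# Toy 'buy' matrix: nonzero entries mark user-item interactions.
mat = sp.csr_matrix(np.array([[0, 3, 0],
                              [5, 0, 0]]))
trnLabel = 1 * (mat != 0)    # 1 wherever any interaction exists, else 0
print(trnLabel.toarray())    # [[0 1 0]
                             #  [1 0 0]]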
@@ -181,14 +183,13 @@ def create_edge(data, behavior, beh_idx, pkey_name, pkey_idx):
 loader_dict: Dict[str, NeighborLoader] = {}
 dst_nodes_dict = {}
 split_date: Dict[str, int] = {}
-split_date['train'] = 1103
-split_date['val'] = 1110
+split_date['train'] = 1110
 split_date['test'] = 1111
 
 num_src_nodes = data[src_entity_table].num_nodes
 num_dst_nodes = data[dst_entity_table].num_nodes
 
-for split in ["train", "val", "test"]:
+for split in ["train", "test"]:
     dst_nodes_data = dst_nodes.data < split_date[split]
     dst_nodes_dict[split] = torch.sparse_coo_tensor(
         torch.stack([torch.tensor(dst_nodes.row),
@@ -296,8 +297,12 @@ def train() -> float:
         loss = F.binary_cross_entropy_with_logits(out, target)
         numel = out.numel()
     elif args.model in ['contextgnn', 'shallowrhsgnn']:
+        sampled_dst = torch.unique(dst_index)
         logits = model(batch, src_entity_table, dst_entity_table)
+        logits = logits[:, sampled_dst]
         edge_label_index = torch.stack([src_batch, dst_index], dim=0)
+        import pdb
+        pdb.set_trace()
         loss = sparse_cross_entropy(logits, edge_label_index)
         numel = len(batch[dst_entity_table].batch)
     loss.backward()
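Here sparse_cross_entropy takes edge_label_index as a [2, num_pos] tensor of (src, dst) positive pairs over the logits matrix. As a rough sketch of the idea (an assumption; the repo's actual helper may differ), it amounts to a softmax cross entropy per positive source row:

import torch
import torch.nn.functional as F

def sparse_cross_entropy_sketch(logits: torch.Tensor,
                                edge_label_index: torch.Tensor) -> torch.Tensor:
    # logits: [num_src, num_dst]; edge_label_index: [2, num_pos].
    src, dst = edge_label_index
    # Softmax over every destination column of each positive row,
    # with the paired destination as the target class.
    return F.cross_entropy(logits[src], dst)

One caveat with the new slicing: after logits = logits[:, sampled_dst], the dst half of edge_label_index still indexes the full destination set, so it would apparently need remapping to positions within sampled_dst before the loss is taken; the pdb.set_trace() in this wip commit sits at exactly that point.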
@@ -346,6 +351,11 @@ def test(loader: NeighborLoader, desc: str, target = None) -> np.ndarray:
     # randomly select num_item indices
     batch_user = scores.shape[0]
     num_item = scores.shape[1]
+    trnLabel[batch_user]
+    import pdb
+    # full set is the sampled rhs set
+    # you pick 99 items from the sampled rhs set.
+    pdb.set_trace()
     random_items = torch.randint(0, num_item, (batch_user, 100)).to(scores.device)  # Shape: (batch_user, 100)
     for i in range(batch_size):
         user_idx = batch[src_entity_table].n_id[i]
@@ -354,8 +364,6 @@ def test(loader: NeighborLoader, desc: str, target = None) -> np.ndarray:
         random_items[i, 0] = target_item
 
     selected_scores = torch.gather(scores, 1, random_items)
-    import pdb
-    pdb.set_trace()
     _, top_k_indices = torch.topk(selected_scores, args.eval_k, dim=1)  # Shape: (num_user, args.eval_k)
 
     pred_mini = random_items[top_k_indices.tolist()]
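The new comments describe a sampled ranking protocol: each user's held-out target is ranked against 99 randomly drawn items rather than the full catalog (100 candidate slots, with slot 0 overwritten by the target). A self-contained sketch of that protocol, with illustrative shapes and names:

import torch

scores = torch.randn(4, 1000)                  # [batch_user, num_item] model scores
target = torch.randint(0, 1000, (4,))          # held-out positive item per user

candidates = torch.randint(0, 1000, (4, 100))  # 100 random candidate items per user
candidates[:, 0] = target                      # slot 0 always holds the true target
cand_scores = torch.gather(scores, 1, candidates)
top_k = cand_scores.topk(k=10, dim=1).indices  # ranks within the 100 candidates
hit_at_10 = (top_k == 0).any(dim=1).float().mean()  # target lives in column 0

Because the target is pinned to column 0, Hit@k reduces to whether column 0 survives the top-k cut; this mirrors the gather/topk flow in the hunk above.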
