Skip to content

Commit

Permalink
reduce the number of users
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiwen Yuan committed Nov 6, 2024
1 parent 582d5ee commit 253515e
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions benchmark/tgt_ijcai_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@
parser.add_argument(
"--num_repeats", type=int, default=2,
help="Number of repeated training and eval on the best config.")
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--channels", type=int, default=128)
parser.add_argument("--aggr", type=str, default="sum")
parser.add_argument("--num_layers", type=int, default=6)
parser.add_argument("--num_neighbors", type=int, default=128)
parser.add_argument("--num_neighbors", type=int, default=64)
parser.add_argument("--temporal_strategy", type=str, default="last")
parser.add_argument("--max_steps_per_epoch", type=int, default=100)
parser.add_argument("--num_workers", type=int, default=0)
Expand Down Expand Up @@ -354,6 +354,7 @@ def test(model: torch.nn.Module, loader: NeighborLoader, desc: str, stage: str,
assert trnLabel is not None
pos_item_per_user = trnLabel[user_idx].coalesce().indices(
).reshape(-1)
pos_item_per_user = pos_item_per_user.to(device)

indices = torch.randint(0, all_sampled_rhs.size(0), (1000, ))
sampled_items = all_sampled_rhs[indices]
Expand Down

0 comments on commit 253515e

Please sign in to comment.