diff --git a/convergence_curve_contextgnn.png b/convergence_curve_contextgnn.png new file mode 100644 index 0000000..2a3e34e Binary files /dev/null and b/convergence_curve_contextgnn.png differ diff --git a/convergence_curve_idgnn.png b/convergence_curve_idgnn.png new file mode 100644 index 0000000..1c17cf8 Binary files /dev/null and b/convergence_curve_idgnn.png differ diff --git a/convergence_curve_shallowrhsgnn.png b/convergence_curve_shallowrhsgnn.png new file mode 100644 index 0000000..6ef0108 Binary files /dev/null and b/convergence_curve_shallowrhsgnn.png differ diff --git a/examples/.relbench_example.py.swp b/examples/.relbench_example.py.swp new file mode 100644 index 0000000..a74ecdc Binary files /dev/null and b/examples/.relbench_example.py.swp differ diff --git a/examples/contextgnn_sample_softmax.py b/examples/contextgnn_sample_softmax.py index aadf85b..47212bb 100644 --- a/examples/contextgnn_sample_softmax.py +++ b/examples/contextgnn_sample_softmax.py @@ -35,16 +35,16 @@ from contextgnn.utils import GloveTextEmbedding, RHSEmbeddingMode parser = argparse.ArgumentParser() -parser.add_argument("--dataset", type=str, default="rel-trial") -parser.add_argument("--task", type=str, default="site-sponsor-run") +parser.add_argument("--dataset", type=str, default="rel-amazon") +parser.add_argument("--task", type=str, default="user-item-purchase") parser.add_argument("--lr", type=float, default=0.001) parser.add_argument("--epochs", type=int, default=20) parser.add_argument("--eval_epochs_interval", type=int, default=1) parser.add_argument("--batch_size", type=int, default=512) parser.add_argument("--channels", type=int, default=128) parser.add_argument("--aggr", type=str, default="sum") -parser.add_argument("--num_layers", type=int, default=4) -parser.add_argument("--num_neighbors", type=int, default=128) +parser.add_argument("--num_layers", type=int, default=6) +parser.add_argument("--num_neighbors", type=int, default=64) parser.add_argument("--temporal_strategy", type=str, default="last") parser.add_argument("--max_steps_per_epoch", type=int, default=2000) parser.add_argument("--num_workers", type=int, default=0) @@ -121,7 +121,7 @@ torch_frame_model_kwargs={ "channels": 128, "num_layers": 4, - }, rhs_sample_size=100).to(device) + }, rhs_sample_size=1000).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)