diff --git a/examples/ijcai_example.py b/examples/ijcai_example.py index c85f513..1962b49 100644 --- a/examples/ijcai_example.py +++ b/examples/ijcai_example.py @@ -11,6 +11,7 @@ import torch.nn.functional as F from relbench.modeling.loader import SparseTensor from torch import Tensor +from torch.optim.lr_scheduler import ExponentialLR from torch_frame import stype from torch_frame.data import Dataset from torch_geometric.data import HeteroData @@ -31,7 +32,7 @@ choices=["contextgnn", "idgnn", "shallowrhsgnn"], ) parser.add_argument("--lr", type=float, default=0.001) -parser.add_argument("--epochs", type=int, default=1) +parser.add_argument("--epochs", type=int, default=10) parser.add_argument("--eval_epochs_interval", type=int, default=1) parser.add_argument("--batch_size", type=int, default=512) parser.add_argument("--channels", type=int, default=128) @@ -39,9 +40,10 @@ parser.add_argument("--num_layers", type=int, default=2) parser.add_argument("--num_neighbors", type=int, default=128) parser.add_argument("--temporal_strategy", type=str, default="last") -parser.add_argument("--max_steps_per_epoch", type=int, default=10) +parser.add_argument("--max_steps_per_epoch", type=int, default=2000) parser.add_argument("--num_workers", type=int, default=0) parser.add_argument("--eval_k", type=int, default=10) +parser.add_argument("--gamma_rate", type=int, default=0.95) parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() @@ -95,6 +97,7 @@ def calculate_hit_rate_on_sparse_target(pred: torch.Tensor, crow_indices = target.crow_indices() col_indices = target.col_indices() values = target.values() + assert values is not None # Iterate through each row and check if predictions match ground truth hits = 0 num_rows = val_pred.shape[0] @@ -103,6 +106,8 @@ def calculate_hit_rate_on_sparse_target(pred: torch.Tensor, # Get the ground truth indices for this row row_start = crow_indices[i].item() row_end = crow_indices[i + 1].item() + assert isinstance(row_start, int) + assert isinstance(row_end, int) dst_indices = col_indices[row_start:row_end] bool_indices = values[row_start:row_end] true_indices = dst_indices[bool_indices] @@ -167,6 +172,8 @@ def create_edge(data, behavior, beh_idx, pkey_name, pkey_idx): create_edge(data, behavior, beh_idx, dst_entity_table, torch.tensor(coo_mat.col, dtype=torch.long)) +assert dst_nodes is not None + num_neighbors = [ int(args.num_neighbors // 2**i) for i in range(args.num_layers) ] @@ -256,7 +263,7 @@ def create_edge(data, behavior, beh_idx, pkey_name, pkey_idx): raise ValueError(f"Unsupported model type {args.model}.") optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) - +lr_scheduler = ExponentialLR(optimizer, gamma=args.gamma_rate) def train() -> float: model.train() @@ -296,6 +303,7 @@ def train() -> float: loss.backward() optimizer.step() + lr_scheduler.step() loss_accum += float(loss) * numel count_accum += numel @@ -338,18 +346,19 @@ def test(loader: NeighborLoader, desc: str, target = None) -> np.ndarray: # randomly select num_item indices batch_user = scores.shape[0] num_item = scores.shape[1] - random_items = torch.randint(0, num_item, (batch_user, 100)) # Shape: (batch_user, 100) + random_items = torch.randint(0, num_item, (batch_user, 100)).to(scores.device) # Shape: (batch_user, 100) for i in range(batch_size): user_idx = batch[src_entity_table].n_id[i] target_item = target[user_idx] if target_item is not None: random_items[i, 0] = target_item - selected_scores = scores[torch.arange(batch_user).unsqueeze(1), random_items] # Shape: (num_user, 100) - + selected_scores = torch.gather(scores, 1, random_items) + import pdb + pdb.set_trace() _, top_k_indices = torch.topk(selected_scores, args.eval_k, dim=1) # Shape: (num_user, args.eval_k) - pred_mini = random_items[top_k_indices.cpu().tolist()] + pred_mini = random_items[top_k_indices.tolist()] else: _, pred_mini = torch.topk(scores, k=args.eval_k, dim=1) pred_list.append(pred_mini) @@ -367,25 +376,9 @@ def test(loader: NeighborLoader, desc: str, target = None) -> np.ndarray: for epoch in range(1, args.epochs + 1): train_loss = train() - if epoch % args.eval_epochs_interval == 0: - val_pred = test(loader_dict["val"], desc="Val") - val_metrics[tune_metric] = calculate_hit_rate(val_pred, target_list) - #val_metrics[tune_metric] = calculate_hit_rate_on_sparse_target( - # val_pred, dst_nodes_dict['val'].to(device)) - print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, " - f"Val metrics: {val_metrics}") - - if val_metrics[tune_metric] >= best_val_metric: - best_val_metric = val_metrics[tune_metric] - state_dict = {k: v.cpu() for k, v in model.state_dict().items()} - -assert state_dict is not None -model.load_state_dict(state_dict) -val_pred = test(loader_dict["val"], desc="Best val") -val_metrics = calculate_hit_rate(val_pred, target_list) -#val_metrics = calculate_hit_rate_on_sparse_target(val_pred, -# dst_nodes_dict['val'].to(device)) -print(f"Best val metrics: {val_metrics}") + test_pred = test(loader_dict["test"], desc="Test", target=target_list) + test_metrics = calculate_hit_rate(test_pred, target_list) + print(f"Best test metrics: {test_metrics}") test_pred = test(loader_dict["test"], desc="Test", target=target_list) test_metrics = calculate_hit_rate(test_pred, target_list)