Skip to content

Commit

Permalink
set 100 for test metrics calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiwen Yuan committed Oct 23, 2024
1 parent 138041d commit 4b76450
Showing 1 changed file with 37 additions and 15 deletions.
52 changes: 37 additions & 15 deletions examples/ijcai_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@
choices=["contextgnn", "idgnn", "shallowrhsgnn"],
)
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--eval_epochs_interval", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=512)
parser.add_argument("--channels", type=int, default=128)
parser.add_argument("--aggr", type=str, default="sum")
parser.add_argument("--num_layers", type=int, default=2)
parser.add_argument("--num_neighbors", type=int, default=128)
parser.add_argument("--temporal_strategy", type=str, default="last")
parser.add_argument("--max_steps_per_epoch", type=int, default=2000)
parser.add_argument("--max_steps_per_epoch", type=int, default=10)
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--eval_k", type=int, default=1)
parser.add_argument("--eval_k", type=int, default=10)
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()

Expand All @@ -60,7 +60,7 @@
data = HeteroData()


def calculate_hit_rate(pred: torch.Tensor, target: List[Optional[int]]):
def calculate_hit_rate(pred: torch.Tensor, target: List[Optional[int]], num_candidates=None):
r"""Calculates hit rate when pred is a tensor and target is a list
Args:
pred (torch.Tensor): Prediction tensor of size (num_entity,
Expand All @@ -69,6 +69,8 @@ def calculate_hit_rate(pred: torch.Tensor, target: List[Optional[int]]):
value is None if user doesn't have a next best action.
The value is the dst node id if there is a next best
action.
num_candidates (int, optional): The number of candidates used when
calculating the metrics. (default: :obj:`None`)
"""
hits = 0
total = 0
Expand Down Expand Up @@ -312,7 +314,7 @@ def train() -> float:


@torch.no_grad()
def test(loader: NeighborLoader, desc: str) -> np.ndarray:
def test(loader: NeighborLoader, desc: str, target = None) -> np.ndarray:
model.eval()

pred_list: List[Tensor] = []
Expand All @@ -332,7 +334,24 @@ def test(loader: NeighborLoader, desc: str) -> np.ndarray:
else:
raise ValueError(f"Unsupported model type: {args.model}.")

_, pred_mini = torch.topk(scores, k=args.eval_k, dim=1)
if target is not None:
    # Sampled evaluation: instead of ranking every item, rank a fixed-size
    # set of 100 randomly drawn candidate item ids per user.
    batch_user = scores.shape[0]
    num_item = scores.shape[1]
    # No device argument is given, so the candidates use torch's default
    # device. Shape: (batch_user, 100)
    random_items = torch.randint(0, num_item, (batch_user, 100))
    # Place the ground-truth item (when the user has one) in slot 0 so it
    # is always among the candidates being ranked.
    # NOTE(review): `batch_size` comes from the enclosing scope, which is
    # not visible here; presumably it equals `batch_user` -- confirm.
    for i in range(batch_size):
        user_idx = batch[src_entity_table].n_id[i]
        target_item = target[user_idx]
        if target_item is not None:
            random_items[i, 0] = target_item

    # Scores of the sampled candidates only. Shape: (batch_user, 100)
    selected_scores = scores[torch.arange(batch_user).unsqueeze(1), random_items]

    # Positions (0..99) of the best candidates within each user's sample.
    # Shape: (batch_user, args.eval_k)
    _, top_k_indices = torch.topk(selected_scores, args.eval_k, dim=1)

    # Translate per-row positions back into item ids. Indexing with a
    # nested Python list does not select per-row entries; gather along
    # dim 1 does. The indices are moved to CPU as in the original code.
    pred_mini = torch.gather(random_items, 1, top_k_indices.cpu())
else:
    _, pred_mini = torch.topk(scores, k=args.eval_k, dim=1)
pred_list.append(pred_mini)
pred = torch.cat(pred_list, dim=0)
return pred
Expand All @@ -342,29 +361,32 @@ def test(loader: NeighborLoader, desc: str) -> np.ndarray:
best_val_metric = 0
tune_metric = 'hr'
val_metrics = dict()

with open(osp.join(path, 'tst_int'), 'rb') as fs:
target_list = pickle.load(fs)

for epoch in range(1, args.epochs + 1):
train_loss = train()
if epoch % args.eval_epochs_interval == 0:
val_pred = test(loader_dict["val"], desc="Val")
val_metrics[tune_metric] = calculate_hit_rate_on_sparse_target(
val_pred, dst_nodes_dict['val'].to(device))
val_metrics[tune_metric] = calculate_hit_rate(val_pred, target_list)
#val_metrics[tune_metric] = calculate_hit_rate_on_sparse_target(
# val_pred, dst_nodes_dict['val'].to(device))
print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, "
f"Val metrics: {val_metrics}")

if val_metrics[tune_metric] > best_val_metric:
if val_metrics[tune_metric] >= best_val_metric:
best_val_metric = val_metrics[tune_metric]
state_dict = {k: v.cpu() for k, v in model.state_dict().items()}

assert state_dict is not None
model.load_state_dict(state_dict)
val_pred = test(loader_dict["val"], desc="Best val")
val_metrics = calculate_hit_rate_on_sparse_target(val_pred,
dst_nodes_dict['val'].to(device))
val_metrics = calculate_hit_rate(val_pred, target_list)
#val_metrics = calculate_hit_rate_on_sparse_target(val_pred,
# dst_nodes_dict['val'].to(device))
print(f"Best val metrics: {val_metrics}")

with open(osp.join(path, 'tst_int'), 'rb') as fs:
target_list = pickle.load(fs)

test_pred = test(loader_dict["test"], desc="Test")
test_pred = test(loader_dict["test"], desc="Test", target=target_list)
test_metrics = calculate_hit_rate(test_pred, target_list)
print(f"Best test metrics: {test_metrics}")

0 comments on commit 4b76450

Please sign in to comment.