Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiwen Yuan committed Oct 18, 2024
1 parent 6ddedbd commit 5856967
Showing 1 changed file with 33 additions and 6 deletions.
39 changes: 33 additions & 6 deletions examples/ijcai_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,17 @@
choices=["contextgnn", "idgnn", "shallowrhsgnn"],
)
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--eval_epochs_interval", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=512)
parser.add_argument("--channels", type=int, default=128)
parser.add_argument("--aggr", type=str, default="sum")
parser.add_argument("--num_layers", type=int, default=2)
parser.add_argument("--num_neighbors", type=int, default=128)
parser.add_argument("--temporal_strategy", type=str, default="last")
parser.add_argument("--max_steps_per_epoch", type=int, default=2000)
parser.add_argument("--max_steps_per_epoch", type=int, default=10)
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--eval_k", type=int, default=10)
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()

Expand All @@ -59,6 +60,26 @@
data = HeteroData()


def calculate_hit_rate(pred, target):
    """Calculate the Hit Rate (HR) of per-user predictions.

    A user counts as a hit when at least one of that user's predicted
    values appears among that same user's target values.

    Args:
        pred (np.ndarray): 2D array of shape (num_users, num_preds) holding
            the predicted values of each user.
        target (np.ndarray): 2D array of shape (num_users, any_value)
            holding the target values of each user. A sequence of per-user
            arrays with differing lengths is accepted as well.

    Returns:
        float: Ratio of users with at least one hit to the total number of
        users, or 0.0 when there are no users.

    Raises:
        ValueError: If ``pred`` and ``target`` describe a different number
            of users.
    """
    if len(pred) != len(target):
        raise ValueError(
            f"pred and target must have the same number of users, "
            f"got {len(pred)} and {len(target)}.")

    # Avoid np.mean of an empty array, which yields NaN and a warning.
    if len(pred) == 0:
        return 0.0

    # Membership is tested row by row: calling np.isin on the full arrays
    # would compare a user's predictions against every user's targets and
    # report hits that belong to somebody else.
    hits = [
        bool(np.isin(user_pred, user_target).any())
        for user_pred, user_target in zip(pred, target)
    ]

    # Hit rate is the fraction of users with at least one hit.
    return float(np.mean(hits))


def create_edge(data, behavior, beh_idx, pkey_name, pkey_idx):
# fkey -> pkey edges
edge_index = torch.stack([beh_idx, pkey_idx], dim=0)
Expand Down Expand Up @@ -114,7 +135,7 @@ def create_edge(data, behavior, beh_idx, pkey_name, pkey_idx):
]

loader_dict: Dict[str, NeighborLoader] = {}
dst_nodes_dict: Dict[str, Tuple[NodeType, Tensor]] = {}
dst_nodes_dict = {}
split_date: Dict[str, int] = {}
split_date['train'] = 1103
split_date['val'] = 1110
Expand Down Expand Up @@ -276,21 +297,22 @@ def test(loader: NeighborLoader, desc: str) -> np.ndarray:
else:
raise ValueError(f"Unsupported model type: {args.model}.")

_, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
_, pred_mini = torch.topk(scores, k=args.eval_k, dim=1)
pred_list.append(pred_mini)
pred = torch.cat(pred_list, dim=0).cpu().numpy()
return pred


state_dict = None
best_val_metric = 0
tune_metric = 'hr'
for epoch in range(1, args.epochs + 1):
train_loss = train()
if epoch % args.eval_epochs_interval == 0:
val_pred = test(loader_dict["val"], desc="Val")
import pdb
pdb.set_trace()
val_metrics = task.evaluate(val_pred, task.get_table("val"))
val_metrics = calculate_hit_rate(val_pred, dst_nodes_dict['val'].to_dense().numpy())
print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, "
f"Val metrics: {val_metrics}")

Expand All @@ -301,9 +323,14 @@ def test(loader: NeighborLoader, desc: str) -> np.ndarray:
assert state_dict is not None
model.load_state_dict(state_dict)
val_pred = test(loader_dict["val"], desc="Best val")
val_metrics = task.evaluate(val_pred, task.get_table("val"))
val_metrics = calculate_hit_rate(val_pred, dst_nodes_dict['val'].to_dense().numpy())
print(f"Best val metrics: {val_metrics}")

with open(osp.join(path, 'tst_int'), 'rb') as fs:
mat = pickle.load(fs)
import pdb
pdb.set_trace()

test_pred = test(loader_dict["test"], desc="Test")
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

0 comments on commit 5856967

Please sign in to comment.