Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiwen Yuan committed Oct 23, 2024
2 parents 4b76450 + d0aaa15 commit 64029ef
Showing 1 changed file with 19 additions and 26 deletions.
45 changes: 19 additions & 26 deletions examples/ijcai_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch.nn.functional as F
from relbench.modeling.loader import SparseTensor
from torch import Tensor
from torch.optim.lr_scheduler import ExponentialLR
from torch_frame import stype
from torch_frame.data import Dataset
from torch_geometric.data import HeteroData
Expand All @@ -31,17 +32,18 @@
choices=["contextgnn", "idgnn", "shallowrhsgnn"],
)
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--eval_epochs_interval", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=512)
parser.add_argument("--channels", type=int, default=128)
parser.add_argument("--aggr", type=str, default="sum")
parser.add_argument("--num_layers", type=int, default=2)
parser.add_argument("--num_neighbors", type=int, default=128)
parser.add_argument("--temporal_strategy", type=str, default="last")
parser.add_argument("--max_steps_per_epoch", type=int, default=10)
parser.add_argument("--max_steps_per_epoch", type=int, default=2000)
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--eval_k", type=int, default=10)
# Multiplicative learning-rate decay factor (used as the ExponentialLR gamma).
# It is fractional, so it must be parsed as a float: with type=int, a
# command-line value such as `--gamma_rate 0.9` is rejected by argparse.
parser.add_argument("--gamma_rate", type=float, default=0.95)
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()

Expand Down Expand Up @@ -95,6 +97,7 @@ def calculate_hit_rate_on_sparse_target(pred: torch.Tensor,
crow_indices = target.crow_indices()
col_indices = target.col_indices()
values = target.values()
assert values is not None
# Iterate through each row and check if predictions match ground truth
hits = 0
num_rows = val_pred.shape[0]
Expand All @@ -103,6 +106,8 @@ def calculate_hit_rate_on_sparse_target(pred: torch.Tensor,
# Get the ground truth indices for this row
row_start = crow_indices[i].item()
row_end = crow_indices[i + 1].item()
assert isinstance(row_start, int)
assert isinstance(row_end, int)
dst_indices = col_indices[row_start:row_end]
bool_indices = values[row_start:row_end]
true_indices = dst_indices[bool_indices]
Expand Down Expand Up @@ -167,6 +172,8 @@ def create_edge(data, behavior, beh_idx, pkey_name, pkey_idx):
create_edge(data, behavior, beh_idx, dst_entity_table,
torch.tensor(coo_mat.col, dtype=torch.long))

assert dst_nodes is not None

num_neighbors = [
int(args.num_neighbors // 2**i) for i in range(args.num_layers)
]
Expand Down Expand Up @@ -256,7 +263,7 @@ def create_edge(data, behavior, beh_idx, pkey_name, pkey_idx):
raise ValueError(f"Unsupported model type {args.model}.")

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

lr_scheduler = ExponentialLR(optimizer, gamma=args.gamma_rate)

def train() -> float:
model.train()
Expand Down Expand Up @@ -296,6 +303,7 @@ def train() -> float:
loss.backward()

optimizer.step()
lr_scheduler.step()

loss_accum += float(loss) * numel
count_accum += numel
Expand Down Expand Up @@ -338,18 +346,19 @@ def test(loader: NeighborLoader, desc: str, target = None) -> np.ndarray:
# randomly select num_item indices
batch_user = scores.shape[0]
num_item = scores.shape[1]
random_items = torch.randint(0, num_item, (batch_user, 100)) # Shape: (batch_user, 100)
random_items = torch.randint(0, num_item, (batch_user, 100)).to(scores.device) # Shape: (batch_user, 100)
for i in range(batch_size):
user_idx = batch[src_entity_table].n_id[i]
target_item = target[user_idx]
if target_item is not None:
random_items[i, 0] = target_item

selected_scores = scores[torch.arange(batch_user).unsqueeze(1), random_items] # Shape: (num_user, 100)

selected_scores = torch.gather(scores, 1, random_items)
import pdb
pdb.set_trace()
_, top_k_indices = torch.topk(selected_scores, args.eval_k, dim=1) # Shape: (num_user, args.eval_k)

pred_mini = random_items[top_k_indices.cpu().tolist()]
pred_mini = random_items[top_k_indices.tolist()]
else:
_, pred_mini = torch.topk(scores, k=args.eval_k, dim=1)
pred_list.append(pred_mini)
Expand All @@ -367,25 +376,9 @@ def test(loader: NeighborLoader, desc: str, target = None) -> np.ndarray:

for epoch in range(1, args.epochs + 1):
train_loss = train()
if epoch % args.eval_epochs_interval == 0:
val_pred = test(loader_dict["val"], desc="Val")
val_metrics[tune_metric] = calculate_hit_rate(val_pred, target_list)
#val_metrics[tune_metric] = calculate_hit_rate_on_sparse_target(
# val_pred, dst_nodes_dict['val'].to(device))
print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, "
f"Val metrics: {val_metrics}")

if val_metrics[tune_metric] >= best_val_metric:
best_val_metric = val_metrics[tune_metric]
state_dict = {k: v.cpu() for k, v in model.state_dict().items()}

assert state_dict is not None
model.load_state_dict(state_dict)
val_pred = test(loader_dict["val"], desc="Best val")
val_metrics = calculate_hit_rate(val_pred, target_list)
#val_metrics = calculate_hit_rate_on_sparse_target(val_pred,
# dst_nodes_dict['val'].to(device))
print(f"Best val metrics: {val_metrics}")
test_pred = test(loader_dict["test"], desc="Test", target=target_list)
test_metrics = calculate_hit_rate(test_pred, target_list)
print(f"Best test metrics: {test_metrics}")

test_pred = test(loader_dict["test"], desc="Test", target=target_list)
test_metrics = calculate_hit_rate(test_pred, target_list)
Expand Down

0 comments on commit 64029ef

Please sign in to comment.