Skip to content

Commit

Permalink
so close to val done
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta committed Oct 1, 2024
1 parent 5a42980 commit feb9318
Showing 1 changed file with 80 additions and 35 deletions.
115 changes: 80 additions & 35 deletions benchmark/data_relbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,22 @@ def train(
epoch: int,
task: RecommendationTask,
) -> float:
rowptr, col = train_data
rowptr, col, _ = train_data["train"]
model.train()
global total_optimization_steps
N = len(rowptr) - 1

# FIXME: enable full training
N = args.batch_size * 2 # for debug
idxlist = list(range(N))
np.random.shuffle(idxlist)
for start in tqdm(range(0, N, args.batch_size), desc=f"Epoch {epoch:3d}"):
for start in tqdm(range(0, N, args.batch_size), desc=f"train epoch {epoch:3d}"):
end = min(start + args.batch_size, N)
batch_size = end - start
lhs_index = torch.tensor(idxlist[start:end], dtype=torch.int64,
device=device)
# count = rowptr[lhs_index + 1] - rowptr[lhs_index]
# src_batch, arange = _batched_arange(count)
# dst_index = col[arange + rowptr[lhs_index][src_batch]]
lhs_index = torch.tensor(
idxlist[start:end],
dtype=torch.int64,
device=device,
)
src_batch, dst_index = get_rhs_index(lhs_index, rowptr, col)
# convert rowptr and col to a dense tensor of ones:
x = torch.zeros(
Expand All @@ -157,7 +158,6 @@ def train(
optimizer.step()
optimizer.zero_grad()
total_optimization_steps += 1
optimizer.zero_grad()
return loss.item()


Expand All @@ -179,31 +179,57 @@ def test(
stage: Literal["val", "test"],
) -> float:
model.eval()
rowptr, col = data_dict[stage]
# x is from the training set for the validation set,
# and from the training+validation set for the test set:
if stage == "val":
rowptr, col, edge_index = data_dict["train"]
test_edge_index = data_dict["val"]
elif stage == "test":
# Combine train and val set:
edge_index_train = data_dict["train"]
edge_index_val = data_dict["val"]
edge_index = torch.cat([edge_index_train, edge_index_val], dim=1)
edge_index = coalesce(edge_index)
rowptr = torch._convert_indices_from_coo_to_csr(
input=edge_index[0],
size=task.num_src_nodes,
)
col = edge_index[1]
else:
raise ValueError(f"Invalid stage: {stage}")

N = len(rowptr) - 1
idxlist = list(range(N))
pred_list: list[Tensor] = []
for start in tqdm(range(0, N, args.batch_size), desc=f"Epoch {epoch:3d}"):
for start in tqdm(range(0, N, args.batch_size), desc=f"{stage}: {epoch:3d}"):
end = min(start + args.batch_size, N)
batch_size = end - start
lhs_index = torch.tensor(idxlist[start:end], dtype=torch.int64,
device=device)
lhs_index = torch.tensor(
idxlist[start:end],
dtype=torch.int64,
device=device,
)
lhs_eval_mask = torch.isin(lhs_index, test_edge_index[0])
lhs_index = lhs_index[lhs_eval_mask]
if len(lhs_index) == 0:
continue

src_batch, dst_index = get_rhs_index(lhs_index, rowptr, col)
# convert rowptr and col to a dense tensor of ones:
# x_input is from the training set for the validation set,
# and from the training+validation set for the test set:
x_input = torch.zeros(
batch_size = len(lhs_index)
x = torch.zeros(
(batch_size, task.num_dst_nodes),
dtype=torch.float32,
device=device,
)
x_input[src_batch, dst_index] = 1.0
recon_x, _, _ = model(x_input)
x[src_batch, dst_index] = 1.0
recon_x, _, _ = model(x)
scores = torch.sigmoid(recon_x)
_, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
pred_list.append(pred_mini)

pred = torch.cat(pred_list, dim=0).cpu().numpy()
res = task.evaluate(pred, task.get_table(stage))
pred = torch.cat(pred_list, dim=0).cpu().numpy() # (37004, 10) but should be (37003, 10) whyyyy
res = task.evaluate(pred, task.get_table(stage)) # ValueError: The shape of pred must be (37003, 10), but (37004, 10) given.
return res[LINK_PREDICTION_METRIC]


Expand All @@ -215,6 +241,8 @@ def load_data_dict(
for split in ['train', 'val', 'test']:
split_df = task.get_table(split).df.drop(
columns=['timestamp']).explode(task.dst_entity_col)
# print summary of the df
print(split_df.info())
edge_index = torch.tensor(
[
split_df[task.src_entity_col].values,
Expand All @@ -223,12 +251,12 @@ def load_data_dict(
dtype=torch.int64,
device=device,
)
row, col = coalesce(edge_index)
edge_index = coalesce(edge_index)
rowptr = torch._convert_indices_from_coo_to_csr(
input=row,
input=edge_index[0],
size=task.num_src_nodes,
)
data_dict[split] = (rowptr, col)
data_dict[split] = (rowptr, edge_index[1], edge_index)
return data_dict


Expand Down Expand Up @@ -260,33 +288,50 @@ def main() -> None:
model = MultiVAE(p_dims).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3,
weight_decay=args.wd)
best_val_map = 0.0
for epoch in range(1, args.epochs + 1):
train_loss = train(
model,
optimizer,
data_dict["train"],
data_dict,
device,
args,
epoch,
task,
)
val_map = test(
model,
data_dict,
device,
args,
epoch,
task,
"val",
)
# val_map = test(
# model,
# data_dict["val"],
# device,
# args,
# epoch,
# task,
# "val",
# )
if val_map > best_val_map:
best_val_map = val_map
torch.save(
model.state_dict(),
f'vae_{args.dataset}_{args.task}.pt',
)

print(f'Epoch {epoch:3d}, '
f'train_loss {train_loss:4.2f}, '
# f'val_map {val_map:4.2f}'
)
f'val_map {val_map:4.2f}')

# TODO: test from saved model
# with open(args.result_path, 'rb') as f:
# model = torch.load(f)
test_map = test(
model,
data_dict,
device,
args,
epoch,
task,
"test",
)
print(f"val_map: {val_map}, test_map: {test_map}")


if __name__ == '__main__':
Expand Down

0 comments on commit feb9318

Please sign in to comment.