diff --git a/examples/graphbolt/pyg/hetero/node_classification.py b/examples/graphbolt/pyg/hetero/node_classification.py index 4166805c3f92..fb46a0ade970 100644 --- a/examples/graphbolt/pyg/hetero/node_classification.py +++ b/examples/graphbolt/pyg/hetero/node_classification.py @@ -65,7 +65,11 @@ def create_dataloader( node_feature_keys["institute"] = ["feat"] node_feature_keys["fos"] = ["feat"] # Fetch node features for the sampled subgraph. - datapipe = datapipe.fetch_feature(features, node_feature_keys) + datapipe = datapipe.fetch_feature( + features, + node_feature_keys, + overlap_fetch=args.overlap_feature_fetch, + ) # Copy the data to the specified device. if need_copy: diff --git a/examples/graphbolt/pyg/multigpu/node_classification.py b/examples/graphbolt/pyg/multigpu/node_classification.py index d2884b6a87f9..30e076ce50f8 100644 --- a/examples/graphbolt/pyg/multigpu/node_classification.py +++ b/examples/graphbolt/pyg/multigpu/node_classification.py @@ -199,7 +199,7 @@ def weighted_reduce(tensor, weight, dst=0): @torch.compile -def train_step(minibatch, optimizer, model, loss_fn, cooperative): +def train_step(minibatch, optimizer, model, loss_fn): node_features = minibatch.node_features["feat"] labels = minibatch.labels optimizer.zero_grad() @@ -211,9 +211,7 @@ def train_step(minibatch, optimizer, model, loss_fn, cooperative): return loss.detach(), num_correct, labels.size(0) -def train_helper( - rank, dataloader, model, optimizer, loss_fn, device, cooperative -): +def train_helper(rank, dataloader, model, optimizer, loss_fn, device): model.train() # Set the model to training mode total_loss = torch.zeros(1, device=device) # Accumulator for the total loss # Accumulator for the total number of correct predictions @@ -223,7 +221,7 @@ def train_helper( start = time.time() for minibatch in tqdm(dataloader, "Training") if rank == 0 else dataloader: loss, num_correct, num_samples = train_step( - minibatch, optimizer, model, loss_fn, cooperative + minibatch, optimizer, model, loss_fn ) total_loss += loss total_correct += num_correct @@ -263,7 +261,6 @@ def train(args, rank, train_dataloader, valid_dataloader, model, device): optimizer, loss_fn, device, - args.cooperative, ) val_acc = evaluate(rank, model, valid_dataloader, device) if rank == 0: @@ -381,7 +378,7 @@ def parse_args(): default=1, help="The number of accesses after which a vertex neighborhood will be cached.", ) - parser.add_argument("--precision", type=str, default="high") + parser.add_argument("--precision", type=str, default="medium") parser.add_argument( "--cooperative", action="store_true",