[GraphBolt] Fix gpu NegativeSampler for seeds. #7068
Conversation
To trigger regression tests:
@yxy235 If you ask for my review as well, it will be easier for me to keep track of what is changing when it comes to GPU GraphBolt support.
LGTM overall, suggested a minor improvement.
I was experimenting to see what the best way to create such tensors is; below you can see what I did. The output of the code below on Colab is as follows:
```python
import torch
import torch.utils.benchmark as benchmark

def f(pos_num, neg_num, dtype=torch.bool, device="cuda:0"):
    # Concatenate a ones block and a zeros block into one labels tensor.
    return torch.cat(
        (
            torch.ones(pos_num, dtype=dtype, device=device),
            torch.zeros(neg_num, dtype=dtype, device=device),
        ),
    )

def g(pos_num, neg_num, dtype=torch.bool, device="cuda:0"):
    # Allocate once, then fill the two halves in place, avoiding the
    # two temporaries and the extra copy that torch.cat performs.
    labels = torch.empty(pos_num + neg_num, dtype=dtype, device=device)
    labels[:pos_num] = 1
    labels[pos_num:] = 0
    return labels

assert torch.equal(f(10, 20), g(10, 20))

N = 10000000
neg_factor = 2
stmt = f'f({N}, {N * neg_factor})'
f_timer = benchmark.Timer(stmt=stmt, setup='import torch', globals={'f': f})
# Reuse the same statement for g by binding g to the name `f`.
g_timer = benchmark.Timer(stmt=stmt, setup='import torch', globals={'f': g})
f_timer.timeit(1000), g_timer.timeit(1000)
```
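For completeness, `Timer.timeit` returns a `Measurement` whose `mean` attribute gives the average seconds per run; a minimal way to read it (a sketch using the in-place variant on CPU so it runs anywhere, not the GPU setup above) is:

```python
import torch
import torch.utils.benchmark as benchmark

def g(pos_num, neg_num, dtype=torch.bool, device="cpu"):
    # In-place variant: allocate once, fill the two halves.
    labels = torch.empty(pos_num + neg_num, dtype=dtype, device=device)
    labels[:pos_num] = 1
    labels[pos_num:] = 0
    return labels

timer = benchmark.Timer(stmt="g(1000, 2000)", globals={"g": g})
m = timer.timeit(100)
print(m.mean)  # average seconds per run
```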
My experiment and suggestion above are a nit; I just wanted to see what the best way to do it is.
I see. I will change it later for better performance.
LGTM with minor nit comments that don't need to be addressed for this PR. However, we might want to scan the whole code base and make similar improvements. I think such small improvements, when applied to the whole codebase, will make a meaningful difference in performance.
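Since the fix amounts to keeping the sampler's auxiliary tensors on the same device as the seeds, the core idea can be sketched as follows (a minimal illustration; `move_auxiliary_to_seeds_device` is a hypothetical helper for this write-up, not GraphBolt API):

```python
import torch

def move_auxiliary_to_seeds_device(seeds, indexes, labels):
    # When the seeds tensor lives on the GPU, the per-seed indexes and
    # labels produced by the negative sampler must be created on (or
    # moved to) the same device; otherwise downstream kernels receive
    # tensors on mixed devices and fail.
    device = seeds.device
    return seeds, indexes.to(device), labels.to(device)
```

On a CPU-only machine the `.to` calls are no-ops; with CUDA seeds they move the auxiliary tensors next to the seeds.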
Description
- Move `seeds`, `indexes`, `labels` to GPU when sampling on GPU.
- Fix `NegativeSampler` on GPU.

Checklist
Please feel free to remove inapplicable items for your PR.
Changes