update node2vec #58
Open
YueZhong-bio wants to merge 36 commits into awslabs:master from YueZhong-bio:link_predictor_zy
Changes from all commits (36 commits)
e36134c full_graph_link_predictor (YueZhong-bio)
1db25a1 Merge pull request #1 from seqRep/full_graph_link_predictor (YueZhong-bio)
0b0120d Merge branch 'master' into master (mufeili)
bfda28a Create gcn_link_predictor.py (YueZhong-bio)
ced4414 Create sage_link_predictor.py (YueZhong-bio)
2a5c8ad Create test_link_prediction.py (YueZhong-bio)
8ebcabd Create full_graph_link_predictor.py (YueZhong-bio)
da93c7e Create logger.py (YueZhong-bio)
6bb813c Add files via upload (YueZhong-bio)
032f875 Add files via upload (YueZhong-bio)
b06fb19 Update README.md (YueZhong-bio)
7930972 Update README.md (YueZhong-bio)
1d310df Update README.md (YueZhong-bio)
3bec551 Delete full_graph_link_predictor.py (YueZhong-bio)
9099236 Merge branch 'master' into link_predictor_zy (YueZhong-bio)
28d292f Merge branch 'master' into link_predictor_zy (mufeili)
f949043 Merge branch 'master' into link_predictor_zy (mufeili)
d005beb Update (mufeili)
b86f0cf Update (mufeili)
c3c0438 Fix (mufeili)
2a744ac Merge pull request #3 from YueZhong-bio/lmf (YueZhong-bio)
a2bde0f Update (#5) (mufeili)
5435ffd Fix (#6) (mufeili)
958d5b3 Try CI (#7) (mufeili)
fa27bf3 Merge branch 'master' into link_predictor_zy (mufeili)
b490b12 CI (#8) (mufeili)
03eb4b1 Update full_graph_link_predictor.py (YueZhong-bio)
53459ad Update README.md (YueZhong-bio)
a518ea2 Update full_graph_link_predictor.py (YueZhong-bio)
0a0fbf0 Update full_graph_link_predictor.py (YueZhong-bio)
40750b0 Merge branch 'master' into link_predictor_zy (mufeili)
d9289f7 Merge branch 'master' into link_predictor_zy (YueZhong-bio)
ac3c394 Update README.md (YueZhong-bio)
cf92a7e Update README.md (YueZhong-bio)
d390185 Add files via upload (YueZhong-bio)
cb31eb6 Update README.md (YueZhong-bio)
New file (+203 lines):

import argparse

import dgl
import torch
from torch.nn import Embedding
from torch.utils.data import DataLoader
from torch_sparse import SparseTensor
from sklearn.linear_model import LogisticRegression

from ogb.linkproppred import DglLinkPropPredDataset

EPS = 1e-15


def save_embedding(model):
    torch.save(model.embedding.weight.data.cpu(), 'embedding.pt')


class Node2Vec(torch.nn.Module):
    r"""The Node2Vec model from the
    `"node2vec: Scalable Feature Learning for Networks"
    <https://arxiv.org/abs/1607.00653>`_ paper, where random walks of
    length :obj:`walk_length` are sampled in a given graph and node
    embeddings are learned via negative-sampling optimization.

    Args:
        data: The graph.
        edge_index (LongTensor): The edge indices.
        embedding_dim (int): The size of each embedding vector.
        walk_length (int): The walk length.
        context_size (int): The actual context size considered for
            positive samples. This parameter increases the effective
            sampling rate by reusing samples across different source nodes.
        walks_per_node (int, optional): The number of walks to sample for
            each node. (default: :obj:`1`)
        p (float, optional): Likelihood of immediately revisiting a node in
            the walk. (default: :obj:`1`)
        q (float, optional): Control parameter to interpolate between
            breadth-first and depth-first strategies. (default: :obj:`1`)
        num_negative_samples (int, optional): The number of negative samples
            to use for each positive sample. (default: :obj:`1`)
        num_nodes (int, optional): The number of nodes. (default: :obj:`None`)
        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. the
            weight matrix will be sparse. (default: :obj:`False`)
    """
    def __init__(self, data, edge_index, embedding_dim, walk_length,
                 context_size, walks_per_node=1, p=1, q=1,
                 num_negative_samples=1, num_nodes=None, sparse=False):
        super(Node2Vec, self).__init__()

        self.data = data
        N = num_nodes
        row, col = edge_index
        self.adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N))
        self.adj = self.adj.to('cpu')

        assert walk_length >= context_size

        self.embedding_dim = embedding_dim
        self.walk_length = walk_length - 1
        self.context_size = context_size
        self.walks_per_node = walks_per_node
        self.p = p
        self.q = q
        self.num_negative_samples = num_negative_samples

        self.embedding = Embedding(N, embedding_dim, sparse=sparse)

        self.reset_parameters()

    def reset_parameters(self):
        self.embedding.reset_parameters()

    def forward(self, batch=None):
        """Returns the embeddings for the nodes in :obj:`batch`."""
        emb = self.embedding.weight
        return emb if batch is None else emb[batch]

    def loader(self, **kwargs):
        return DataLoader(range(self.adj.sparse_size(0)),
                          collate_fn=self.sample, **kwargs)

    def pos_sample(self, batch):
        # Sample `walks_per_node` random walks per batch node and slice
        # them into overlapping context windows. Note that
        # dgl.sampling.random_walk samples uniform walks; self.p and
        # self.q are not applied here.
        batch = batch.repeat(self.walks_per_node)
        rw = dgl.sampling.random_walk(dgl.graph(self.data.edges()), batch,
                                      length=self.walk_length)[0]

        walks = []
        num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
        for j in range(num_walks_per_rw):
            walks.append(rw[:, j:j + self.context_size])
        return torch.cat(walks, dim=0)

    def neg_sample(self, batch):
        # Negative "walks" are sequences of uniformly sampled random nodes.
        batch = batch.repeat(self.walks_per_node * self.num_negative_samples)

        rw = torch.randint(self.adj.sparse_size(0),
                           (batch.size(0), self.walk_length))
        rw = torch.cat([batch.view(-1, 1), rw], dim=-1)

        walks = []
        num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
        for j in range(num_walks_per_rw):
            walks.append(rw[:, j:j + self.context_size])
        return torch.cat(walks, dim=0)

    def sample(self, batch):
        if not isinstance(batch, torch.Tensor):
            batch = torch.tensor(batch)
        return self.pos_sample(batch), self.neg_sample(batch)

    def loss(self, pos_rw, neg_rw):
        r"""Computes the loss given positive and negative random walks."""
        # Positive loss.
        start, rest = pos_rw[:, 0], pos_rw[:, 1:].contiguous()

        h_start = self.embedding(start).view(pos_rw.size(0), 1,
                                             self.embedding_dim)
        h_rest = self.embedding(rest.view(-1)).view(pos_rw.size(0), -1,
                                                    self.embedding_dim)

        out = (h_start * h_rest).sum(dim=-1).view(-1)
        pos_loss = -torch.log(torch.sigmoid(out) + EPS).mean()

        # Negative loss.
        start, rest = neg_rw[:, 0], neg_rw[:, 1:].contiguous()

        h_start = self.embedding(start).view(neg_rw.size(0), 1,
                                             self.embedding_dim)
        h_rest = self.embedding(rest.view(-1)).view(neg_rw.size(0), -1,
                                                    self.embedding_dim)

        out = (h_start * h_rest).sum(dim=-1).view(-1)
        neg_loss = -torch.log(1 - torch.sigmoid(out) + EPS).mean()

        return pos_loss + neg_loss

    def test(self, train_z, train_y, test_z, test_y, solver='lbfgs',
             multi_class='auto', *args, **kwargs):
        r"""Evaluates latent space quality via a logistic regression
        downstream task."""
        clf = LogisticRegression(solver=solver, multi_class=multi_class,
                                 *args, **kwargs)
        clf.fit(train_z.detach().cpu().numpy(),
                train_y.detach().cpu().numpy())
        return clf.score(test_z.detach().cpu().numpy(),
                         test_y.detach().cpu().numpy())

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__,
                                   self.embedding.weight.size(0),
                                   self.embedding.weight.size(1))


def main():
    parser = argparse.ArgumentParser(description='OGBL-PPA (Node2Vec)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--embedding_dim', type=int, default=128)
    parser.add_argument('--walk_length', type=int, default=40)
    parser.add_argument('--context_size', type=int, default=20)
    parser.add_argument('--walks_per_node', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--log_steps', type=int, default=1)
    args = parser.parse_args()

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = DglLinkPropPredDataset(name='ogbl-ppa')
    data = dataset[0]
    edge_index = torch.stack((data.edges()[0], data.edges()[1]), dim=0)

    model = Node2Vec(data, edge_index, args.embedding_dim, args.walk_length,
                     args.context_size, args.walks_per_node,
                     num_nodes=data.number_of_nodes(),
                     sparse=True).to(device)

    loader = model.loader(batch_size=args.batch_size, shuffle=True,
                          num_workers=4)
    optimizer = torch.optim.SparseAdam(model.parameters(), lr=args.lr)

    model.train()
    for epoch in range(1, args.epochs + 1):
        for i, (pos_rw, neg_rw) in enumerate(loader):
            optimizer.zero_grad()
            loss = model.loss(pos_rw.to(device), neg_rw.to(device))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            if (i + 1) % args.log_steps == 0:
                print(f'Epoch: {epoch:02d}, Step: {i+1:03d}/{len(loader)}, '
                      f'Loss: {loss:.4f}')

            if (i + 1) % 100 == 0:  # Save embeddings every 100 steps.
                save_embedding(model)
        save_embedding(model)


if __name__ == "__main__":
    main()
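For reference, the sliding-window slicing in `pos_sample`/`neg_sample` (`rw[:, j:j + self.context_size]`) can be illustrated with plain Python lists; the toy walk values below are hypothetical and no DGL or PyTorch is needed:

```python
def context_windows(walk, context_size):
    """Slice one random walk into overlapping context windows,
    mirroring the loop in pos_sample/neg_sample above."""
    n = len(walk) - context_size + 1
    return [walk[j:j + context_size] for j in range(n)]

walk = [0, 1, 2, 3, 4]  # a toy walk visiting 5 nodes
windows = context_windows(walk, context_size=3)
print(windows)  # [[0, 1, 2], [1, 2, 3], [2, 3, 4]]
```

With the defaults in `main()` (`walk_length=40`, `context_size=20`) each walk therefore yields 21 positive windows, which is the effective-sampling-rate gain the docstring mentions.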
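The `p`/`q` parameters documented in the class bias a second-order walk toward breadth-first or depth-first exploration. Note that this implementation stores `p` and `q` but delegates walking to `dgl.sampling.random_walk`, which samples uniform walks, so the sketch below only illustrates the bias described in the docstring; the `biased_step` helper and toy graph are hypothetical, not part of the PR:

```python
import random

def biased_step(prev, cur, neighbors, p=1.0, q=1.0):
    """One step of a node2vec second-order walk: weight 1/p for
    returning to the previous node, 1 for nodes adjacent to it
    (BFS-like), and 1/q for moving further away (DFS-like)."""
    weights = []
    for nxt in neighbors[cur]:
        if nxt == prev:                       # step back to where we came from
            weights.append(1.0 / p)
        elif prev in neighbors.get(nxt, []):  # distance 1 from prev
            weights.append(1.0)
        else:                                 # distance 2 from prev
            weights.append(1.0 / q)
    return random.choices(neighbors[cur], weights=weights, k=1)[0]

# Toy path graph 0-1-2: standing at node 1 having come from node 0,
# a large p strongly discourages stepping back to 0.
adj = {0: [1], 1: [0, 2], 2: [1]}
random.seed(0)
steps = [biased_step(prev=0, cur=1, neighbors=adj, p=100.0, q=1.0)
         for _ in range(200)]
print(steps.count(0), steps.count(2))
```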
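The objective in `Node2Vec.loss` is a skip-gram negative-sampling loss: the dot product between the start-node embedding and each context embedding is squashed through a sigmoid, and the model minimizes `-log σ(h_u·h_v)` for positive windows and `-log(1 - σ(h_u·h_v))` for negative ones. A dependency-free sketch of the per-window computation; the `walk_loss` helper and the toy embeddings are illustrative, not from the PR:

```python
import math

EPS = 1e-15  # same numerical guard as in the PR

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def walk_loss(h_start, contexts, positive):
    """Average negative-sampling loss of one start-node embedding
    against a list of context embeddings, mirroring Node2Vec.loss."""
    total = 0.0
    for h_ctx in contexts:
        score = sum(a * b for a, b in zip(h_start, h_ctx))
        prob = sigmoid(score)
        total += -math.log((prob if positive else 1 - prob) + EPS)
    return total / len(contexts)

# Toy 2-d embeddings: an aligned pair scores low as a positive
# example and high when treated as a negative example.
h_start = [1.0, 0.0]
pos = walk_loss(h_start, [[1.0, 0.0]], positive=True)
neg = walk_loss(h_start, [[1.0, 0.0]], positive=False)
print(f'pos={pos:.3f} neg={neg:.3f}')
```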
Review comment: No "be"