Skip to content

Commit

Permalink
Merge pull request #49 from MilkshakeForReal/edge-features
Browse files Browse the repository at this point in the history
fix
  • Loading branch information
YichengDWu authored Jul 4, 2022
2 parents f0edb41 + 774ff88 commit 5201e6b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
8 changes: 4 additions & 4 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -524,17 +524,17 @@ function (l::GNOConv{bias})(x::AbstractMatrix, ps, st::NamedTuple) where {bias}
si, sj = xi[nkeys], xj[nkeys]
si, sj = reduce(vcat, values(si), init = initarray),
reduce(vcat, values(sj), init = initarray)

e = reduce(vcat, values(e), init = initarray)

W, st_ϕ = l.ϕ(vcat(si, sj, e), ps.ϕ, st.ϕ)
st = merge(st, (; ϕ = st_ϕ))

hj = xj.h_
nin, nedges = size(hj)
W = reshape(W, :, nin, nedges)
hj = reshape(hj, (nin, 1, nedges))
W = reshape(W, :, l.in_chs, num_edges)
hj = reshape(hj, l.in_chs, 1, num_edges)
m = NNlib.batched_mul(W, hj)
return reshape(m, :, nedges)
return reshape(m, :, num_edges)
end

xs = merge((; h_ = x), s)
Expand Down
4 changes: 2 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ import Base: copy
@inline drop(nt::NamedTuple, key::Symbol) = Base.structdiff(nt, NamedTuple{(key,)})

"""
copy(g::GNNGraph)
copy(g::GNNGraph, kwarg...)
Create a shollow copy of the input graph `g`. This is equivalent to `GNNGraph(g)`.
"""
copy(g::GNNGraph) = GNNGraph(g)
copy(g::GNNGraph, kwarg...) = GNNGraph(g, kwarg...)

@doc raw"""
wrapgraph(g::GNNGraph) = () -> copy(g)
Expand Down
11 changes: 5 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,22 +126,21 @@ import Flux: batch, unbatch
ps, st = Lux.setup(rng, l)

y, st = l(h, ps, st)
@test size(y) == (7, g.num_nodes)
@test size(y) == (out_chs, g.num_nodes)

l = GNOConv(5 => 7, ϕ, initialgraph = g, bias = false)
l = GNOConv(in_chs => out_chs, ϕ, initialgraph = g)
rng = Random.default_rng()
ps, st = Lux.setup(rng, l)

y, st = l(h, ps, st)
@test size(y) == (7, g.num_nodes)
@test size(y) == (out_chs, g.num_nodes)

e = rand(2 + 2 + 3 + 3, g.num_edges)
g = GNNGraph(edge_index(g), edata = e)
g = GNNGraph(g, ndata = NamedTuple(), edata = rand(2 + 2 + 3 + 3, g.num_edges))

st = updategraph(st, g)
y, st = l(h, ps, st)

@test size(y) == (7, g.num_nodes)
@test size(y) == (out_chs, g.num_nodes)
end
end
end
Expand Down

0 comments on commit 5201e6b

Please sign in to comment.