Skip to content

Commit

Permalink
Merge pull request #57 from MilkshakeForReal/input
Browse files Browse the repository at this point in the history
relax input size
  • Loading branch information
YichengDWu authored Jul 13, 2022
2 parents 7c96745 + fc8ab99 commit 9e2fb4a
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ function (l::VMHConv)(x::AbstractArray, ps, st::NamedTuple)
end

function (l::VMHConv)(x::NamedTuple, ps, st::NamedTuple)
x = _flatten(x)
function message(xi, xj, e)
posi, posj = xi.x, xj.x
hi, hj = values(drop(xi, :x)), values(drop(xj, :x))
Expand Down Expand Up @@ -391,6 +392,7 @@ function MPPDEConv(ϕ::AbstractExplicitLayer, ψ::AbstractExplicitLayer;
end

function (l::MPPDEConv)(x::AbstractArray, ps, st::NamedTuple)
x = _flatten(x)
g = st.graph
num_nodes = g.num_nodes
num_edges = g.num_edges
Expand Down Expand Up @@ -519,7 +521,8 @@ function GNOConv(ch::Pair{Int, Int}, ϕ::AbstractExplicitLayer, activation = ide
GNOConv{bias, typeof(aggr)}(first(ch), last(ch), initialgraph, aggr, linear, ϕ)
end

function (l::GNOConv{bias})(x::AbstractMatrix, ps, st::NamedTuple) where {bias}
function (l::GNOConv{bias})(x::AbstractArray, ps, st::NamedTuple) where {bias}
x = _flatten(x)
g = st.graph
s = g.ndata
nkeys = keys(s)
Expand Down
9 changes: 8 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ copy(g::GNNGraph, kwarg...) = GNNGraph(g, kwarg...)
wrapgraph(g::GNNGraph) = () -> copy(g)
wrapgraph(f::Function) = f
Creater a function wrapper of the input graph.
Creater a function wrapper of the input graph.
"""
wrapgraph(g::GNNGraph) = () -> copy(g)
wrapgraph(f::Function) = f
Expand All @@ -33,3 +33,10 @@ function updategraph(st::NamedTuple, g::GNNGraph)
end
return st
end

@inline _flatten(x::AbstractMatrix) = x
@inline function _flatten(x::AbstractArray{T, N}) where {T, N}
s = size(x)
return reshape(x, s[1:(end - 2)]..., s[end - 1] * s[end])
end
@inline _flatten(x::NamedTuple) = map(d -> _flatten(d), x)
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ using SafeTestsets
ps, st = Lux.setup(rng, l)
y, st = l(h, ps, st)
@test size(y) == (7, gh.num_nodes)

h = randn(T, 5, g.num_nodes, 2)

ps, st = Lux.setup(rng, l)
y, st = l(h, ps, st)
@test size(y) == (7, gh.num_nodes)
end

@testset "Without theta" begin
Expand Down

0 comments on commit 9e2fb4a

Please sign in to comment.