Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions GraphNeuralNetworks/test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

Flux.testmode!(gnn)

test_gradients(gnn, g, x, rtol = 1e-5)
test_gradients(gnn, g, x, rtol = 1e-5, test_mooncake = false)

@testset "constructor with names" begin
m = GNNChain(GCNConv(din => d),
Expand Down Expand Up @@ -53,7 +53,7 @@

Flux.trainmode!(gnn)

test_gradients(gnn, g, x, rtol = 1e-4, atol=1e-4)
test_gradients(gnn, g, x, rtol = 1e-4, atol=1e-4, test_mooncake = false)
end
end

Expand Down
19 changes: 7 additions & 12 deletions GraphNeuralNetworks/test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ end
for g in TEST_GRAPHS
g = add_self_loops(g)
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
# Note: test_mooncake not enabled for ChebConv (Mooncake backward pass error)
test_gradients(l, g, g.x, rtol = RTOL_LOW)
test_gradients(l, g, g.x, rtol = RTOL_LOW, test_mooncake = false)
end

@testset "bias=false" begin
Expand Down Expand Up @@ -198,8 +197,7 @@ end
l = GATv2Conv(D_IN => D_OUT, tanh; heads, concat, dropout=0)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (concat ? heads * D_OUT : D_OUT, g.num_nodes)
# Mooncake backward pass error for this layer on CI
test_gradients(l, g, g.x, rtol = RTOL_LOW, atol=ATOL_LOW)
test_gradients(l, g, g.x, rtol = RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE)
end
end

Expand All @@ -208,8 +206,7 @@ end
l = GATv2Conv((D_IN, ein) => D_OUT, add_self_loops = false, dropout=0)
g = GNNGraph(TEST_GRAPHS[1], edata = rand(Float32, ein, TEST_GRAPHS[1].num_edges))
@test size(l(g, g.x, g.e)) == (D_OUT, g.num_nodes)
# Mooncake backward pass error for this layer on CI
test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW, atol=ATOL_LOW)
test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE)
end

@testset "num params" begin
Expand Down Expand Up @@ -568,31 +565,30 @@ end
ein = 2
heads = 3
# used like in Kool et al., 2019
# Mooncake backward pass error for this layer on CI
l = TransformerConv(D_IN * heads => D_IN; heads, add_self_loops = true,
root_weight = false, ff_channels = 10, skip_connection = true,
batch_norm = false)
# batch_norm=false here for tests to pass; true in paper
for g in TEST_GRAPHS
g = GNNGraph(g, ndata = rand(Float32, D_IN * heads, g.num_nodes))
@test size(l(g, g.x)) == (D_IN * heads, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_LOW)
test_gradients(l, g, g.x, rtol = RTOL_LOW, test_mooncake = TEST_MOONCAKE)
end
# used like in Shi et al., 2021
l = TransformerConv((D_IN, ein) => D_IN; heads, gating = true,
bias_qkv = true)
for g in TEST_GRAPHS
g = GNNGraph(g, edata = rand(Float32, ein, g.num_edges))
@test size(l(g, g.x, g.e)) == (D_IN * heads, g.num_nodes)
test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW)
test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW, test_mooncake = TEST_MOONCAKE)
end
# test averaging heads
l = TransformerConv(D_IN => D_IN; heads, concat = false,
bias_root = false,
root_weight = false)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_IN, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_LOW)
test_gradients(l, g, g.x, rtol = RTOL_LOW, test_mooncake = TEST_MOONCAKE)
end
end

Expand Down Expand Up @@ -620,8 +616,7 @@ end
l = DConv(D_IN => D_OUT, k)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
# Note: test_mooncake not enabled for DConv (Mooncake backward pass error)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = false)
end
end
end
Expand Down
4 changes: 2 additions & 2 deletions GraphNeuralNetworks/test/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
@test u[:, [1]] ≈ sum(g.ndata.x[:, 1:n], dims = 2)
@test p(g).gdata.u == u

test_gradients(p, g, g.x, rtol = 1e-5)
test_gradients(p, g, g.x, rtol = 1e-5, test_mooncake = TEST_MOONCAKE)
end
end

Expand All @@ -42,7 +42,7 @@ end
for i in 1:ng])

@test size(p(g, g.x)) == (chout, ng)
test_gradients(p, g, g.x, rtol = 1e-5)
test_gradients(p, g, g.x, rtol = 1e-5, test_mooncake = TEST_MOONCAKE)
end
end

Expand Down
58 changes: 29 additions & 29 deletions GraphNeuralNetworks/test/layers/temporalconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ end
@test y === h
@test size(h) == (out_channel, g.num_nodes)
# with no initial state
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_HIGH)
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
# with initial state
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_HIGH)
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_HIGH, test_mooncake = TEST_MOONCAKE)

# Test with custom activation function
custom_activation = tanh
Expand All @@ -45,9 +45,9 @@ end
# Test that outputs differ when using different activation functions
@test !isapprox(y, y_custom, rtol=RTOL_HIGH)
# with no initial state
test_gradients(cell_custom, g, g.x, loss=cell_loss, rtol=RTOL_HIGH)
test_gradients(cell_custom, g, g.x, loss=cell_loss, rtol=RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
# with initial state
test_gradients(cell_custom, g, g.x, h_custom, loss=cell_loss, rtol=RTOL_HIGH)
test_gradients(cell_custom, g, g.x, h_custom, loss=cell_loss, rtol=RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end

@testitem "TGCN" setup=[TemporalConvTestModule, TestModule] begin
Expand All @@ -61,9 +61,9 @@ end
@test layer isa GNNRecurrence
@test size(y) == (out_channel, timesteps, g.num_nodes)
# with no initial state
test_gradients(layer, g, x, rtol = RTOL_HIGH)
test_gradients(layer, g, x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
# with initial state
test_gradients(layer, g, x, state0, rtol = RTOL_HIGH)
test_gradients(layer, g, x, state0, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)

# Test with custom activation function
custom_activation = tanh
Expand All @@ -74,15 +74,15 @@ end
# Test that outputs differ when using different activation functions
@test !isapprox(y, y_custom, rtol = RTOL_HIGH)
# with no initial state
test_gradients(layer_custom, g, x, rtol = RTOL_HIGH)
test_gradients(layer_custom, g, x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
# with initial state
test_gradients(layer_custom, g, x, state0, rtol = RTOL_HIGH)
test_gradients(layer_custom, g, x, state0, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)

# interplay with GNNChain
model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1))
y = model(g, x)
@test size(y) == (1, timesteps, g.num_nodes)
test_gradients(model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW)
test_gradients(model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW, test_mooncake = TEST_MOONCAKE)
end

@testitem "GConvLSTMCell" setup=[TemporalConvTestModule, TestModule] begin
Expand All @@ -93,9 +93,9 @@ end
@test size(h) == (out_channel, g.num_nodes)
@test size(c) == (out_channel, g.num_nodes)
# with no initial state
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
# with initial state
test_gradients(cell, g, g.x, (h, c), loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(cell, g, g.x, (h, c), loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
end

@testitem "GConvLSTM" setup=[TemporalConvTestModule, TestModule] begin
Expand All @@ -107,15 +107,15 @@ end
y = layer(g, x)
@test size(y) == (out_channel, timesteps, g.num_nodes)
# with no initial state
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
# with initial state
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)

# interplay with GNNChain
model = GNNChain(GConvLSTM(in_channel => out_channel, 2), Dense(out_channel, 1))
y = model(g, x)
@test size(y) == (1, timesteps, g.num_nodes)
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW)
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW, test_mooncake = false)
end

@testitem "GConvGRUCell" setup=[TemporalConvTestModule, TestModule] begin
Expand All @@ -125,9 +125,9 @@ end
@test y === h
@test size(h) == (out_channel, g.num_nodes)
# with no initial state
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
# with initial state
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
end


Expand All @@ -140,15 +140,15 @@ end
y = layer(g, x)
@test size(y) == (out_channel, timesteps, g.num_nodes)
# with no initial state
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
# with initial state
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)

# interplay with GNNChain
model = GNNChain(GConvGRU(in_channel => out_channel, 2), Dense(out_channel, 1))
y = model(g, x)
@test size(y) == (1, timesteps, g.num_nodes)
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW)
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW, test_mooncake = false)
end

@testitem "DCGRUCell" setup=[TemporalConvTestModule, TestModule] begin
Expand All @@ -158,9 +158,9 @@ end
@test y === h
@test size(h) == (out_channel, g.num_nodes)
# with no initial state
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
# with initial state
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
end

@testitem "DCGRU" setup=[TemporalConvTestModule, TestModule] begin
Expand All @@ -172,15 +172,15 @@ end
y = layer(g, x)
@test size(y) == (out_channel, timesteps, g.num_nodes)
# with no initial state
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
# with initial state
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)

# interplay with GNNChain
model = GNNChain(DCGRU(in_channel => out_channel, 2), Dense(out_channel, 1))
y = model(g, x)
@test size(y) == (1, timesteps, g.num_nodes)
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW)
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW, test_mooncake = false)
end

@testitem "EvolveGCNOCell" setup=[TemporalConvTestModule, TestModule] begin
Expand All @@ -189,9 +189,9 @@ end
y, state = cell(g, g.x)
@test size(y) == (out_channel, g.num_nodes)
# with no initial state
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
# with initial state
test_gradients(cell, g, g.x, state, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(cell, g, g.x, state, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
end

@testitem "EvolveGCNO" setup=[TemporalConvTestModule, TestModule] begin
Expand All @@ -203,15 +203,15 @@ end
y = layer(g, x)
@test size(y) == (out_channel, timesteps, g.num_nodes)
# with no initial state
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE)
# with initial state
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE)

# interplay with GNNChain
model = GNNChain(EvolveGCNO(in_channel => out_channel), Dense(out_channel, 1))
y = model(g, x)
@test size(y) == (1, timesteps, g.num_nodes)
test_gradients(model, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(model, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE)
end

# @testitem "GINConv" setup=[TemporalConvTestModule, TestModule] begin
Expand Down
24 changes: 10 additions & 14 deletions GraphNeuralNetworks/test/test_module.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,13 @@ function test_gradients(
end

if test_mooncake
# Mooncake gradient with respect to input, compared against Zygote.
# Mooncake gradient with respect to input via Flux integration, compared against Zygote.
loss_mc_x = (xs...) -> loss(f, graph, xs...)
# TODO error without `invokelatest` when using TestItemRunner
_cache_x = Base.invokelatest(Mooncake.prepare_gradient_cache, loss_mc_x, xs...)
y_mc, g_mc = Base.invokelatest(Mooncake.value_and_gradient!!, _cache_x, loss_mc_x, xs...)
result = Flux.withgradient(loss_mc_x, Flux.AutoMooncake(), xs...)
y_mc = result.val # Extract value from NamedTuple
g_mc = result.grad # Extract gradients tuple
@assert isapprox(y, y_mc; rtol, atol)
for i in eachindex(xs)
@assert isapprox(g[i], g_mc[i+1]; rtol, atol)
end
check_equal_leaves(g, g_mc; rtol, atol)
end

if test_gpu
Expand All @@ -158,14 +156,12 @@ function test_gradients(
end

if test_mooncake
# Mooncake gradient with respect to f, compared against Zygote.
ps_mc, re_mc = Flux.destructure(f)
loss_mc_f = ps -> loss(re_mc(ps), graph, xs...)
_cache_f = Base.invokelatest(Mooncake.prepare_gradient_cache, loss_mc_f, ps_mc)
y_mc, g_mc = Base.invokelatest(Mooncake.value_and_gradient!!, _cache_f, loss_mc_f, ps_mc)
# Mooncake gradient with respect to f via Flux integration, compared against Zygote.
result = Flux.withgradient(f -> loss(f, graph, xs...), Flux.AutoMooncake(), f)
y_mc = result.val # Extract value from NamedTuple
g_mc_result = result.grad # Extract gradients tuple
@assert isapprox(y, y_mc; rtol, atol)
g_mc_f = (re_mc(g_mc[2]),)
check_equal_leaves(g, g_mc_f; rtol, atol)
check_equal_leaves(g, g_mc_result; rtol, atol)
end

if test_gpu
Expand Down
Loading