diff --git a/GraphNeuralNetworks/test/layers/basic.jl b/GraphNeuralNetworks/test/layers/basic.jl index bce86b75d..47b111535 100644 --- a/GraphNeuralNetworks/test/layers/basic.jl +++ b/GraphNeuralNetworks/test/layers/basic.jl @@ -18,7 +18,7 @@ Flux.testmode!(gnn) - test_gradients(gnn, g, x, rtol = 1e-5) + test_gradients(gnn, g, x, rtol = 1e-5, test_mooncake = false) @testset "constructor with names" begin m = GNNChain(GCNConv(din => d), @@ -53,7 +53,7 @@ Flux.trainmode!(gnn) - test_gradients(gnn, g, x, rtol = 1e-4, atol=1e-4) + test_gradients(gnn, g, x, rtol = 1e-4, atol=1e-4, test_mooncake = false) end end diff --git a/GraphNeuralNetworks/test/layers/conv.jl b/GraphNeuralNetworks/test/layers/conv.jl index ecc0fbd78..96137a3ce 100644 --- a/GraphNeuralNetworks/test/layers/conv.jl +++ b/GraphNeuralNetworks/test/layers/conv.jl @@ -86,8 +86,7 @@ end for g in TEST_GRAPHS g = add_self_loops(g) @test size(l(g, g.x)) == (D_OUT, g.num_nodes) - # Note: test_mooncake not enabled for ChebConv (Mooncake backward pass error) - test_gradients(l, g, g.x, rtol = RTOL_LOW) + test_gradients(l, g, g.x, rtol = RTOL_LOW, test_mooncake = false) end @testset "bias=false" begin @@ -198,8 +197,7 @@ end l = GATv2Conv(D_IN => D_OUT, tanh; heads, concat, dropout=0) for g in TEST_GRAPHS @test size(l(g, g.x)) == (concat ? 
heads * D_OUT : D_OUT, g.num_nodes) - # Mooncake backward pass error for this layer on CI - test_gradients(l, g, g.x, rtol = RTOL_LOW, atol=ATOL_LOW) + test_gradients(l, g, g.x, rtol = RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE) end end @@ -208,8 +206,7 @@ end l = GATv2Conv((D_IN, ein) => D_OUT, add_self_loops = false, dropout=0) g = GNNGraph(TEST_GRAPHS[1], edata = rand(Float32, ein, TEST_GRAPHS[1].num_edges)) @test size(l(g, g.x, g.e)) == (D_OUT, g.num_nodes) - # Mooncake backward pass error for this layer on CI - test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW, atol=ATOL_LOW) + test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE) end @testset "num params" begin @@ -568,7 +565,6 @@ end ein = 2 heads = 3 # used like in Kool et al., 2019 - # Mooncake backward pass error for this layer on CI l = TransformerConv(D_IN * heads => D_IN; heads, add_self_loops = true, root_weight = false, ff_channels = 10, skip_connection = true, batch_norm = false) @@ -576,7 +572,7 @@ end for g in TEST_GRAPHS g = GNNGraph(g, ndata = rand(Float32, D_IN * heads, g.num_nodes)) @test size(l(g, g.x)) == (D_IN * heads, g.num_nodes) - test_gradients(l, g, g.x, rtol = RTOL_LOW) + test_gradients(l, g, g.x, rtol = RTOL_LOW, test_mooncake = TEST_MOONCAKE) end # used like in Shi et al., 2021 l = TransformerConv((D_IN, ein) => D_IN; heads, gating = true, @@ -584,7 +580,7 @@ end for g in TEST_GRAPHS g = GNNGraph(g, edata = rand(Float32, ein, g.num_edges)) @test size(l(g, g.x, g.e)) == (D_IN * heads, g.num_nodes) - test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW) + test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW, test_mooncake = TEST_MOONCAKE) end # test averaging heads l = TransformerConv(D_IN => D_IN; heads, concat = false, @@ -592,7 +588,7 @@ end root_weight = false) for g in TEST_GRAPHS @test size(l(g, g.x)) == (D_IN, g.num_nodes) - test_gradients(l, g, g.x, rtol = RTOL_LOW) + test_gradients(l, g, g.x, rtol = RTOL_LOW, test_mooncake = TEST_MOONCAKE) 
end end @@ -620,8 +616,7 @@ end l = DConv(D_IN => D_OUT, k) for g in TEST_GRAPHS @test size(l(g, g.x)) == (D_OUT, g.num_nodes) - # Note: test_mooncake not enabled for DConv (Mooncake backward pass error) - test_gradients(l, g, g.x, rtol = RTOL_HIGH) + test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = false) end end end diff --git a/GraphNeuralNetworks/test/layers/pool.jl b/GraphNeuralNetworks/test/layers/pool.jl index fa1475b20..7300b9aa3 100644 --- a/GraphNeuralNetworks/test/layers/pool.jl +++ b/GraphNeuralNetworks/test/layers/pool.jl @@ -19,7 +19,7 @@ @test u[:, [1]] ≈ sum(g.ndata.x[:, 1:n], dims = 2) @test p(g).gdata.u == u - test_gradients(p, g, g.x, rtol = 1e-5) + test_gradients(p, g, g.x, rtol = 1e-5, test_mooncake = TEST_MOONCAKE) end end @@ -42,7 +42,7 @@ end for i in 1:ng]) @test size(p(g, g.x)) == (chout, ng) - test_gradients(p, g, g.x, rtol = 1e-5) + test_gradients(p, g, g.x, rtol = 1e-5, test_mooncake = TEST_MOONCAKE) end end diff --git a/GraphNeuralNetworks/test/layers/temporalconv.jl b/GraphNeuralNetworks/test/layers/temporalconv.jl index 93d6b0082..6489cba98 100644 --- a/GraphNeuralNetworks/test/layers/temporalconv.jl +++ b/GraphNeuralNetworks/test/layers/temporalconv.jl @@ -32,9 +32,9 @@ end @test y === h @test size(h) == (out_channel, g.num_nodes) # with no initial state - test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_HIGH) + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_HIGH, test_mooncake = TEST_MOONCAKE) # with initial state - test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_HIGH) + test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_HIGH, test_mooncake = TEST_MOONCAKE) # Test with custom activation function custom_activation = tanh @@ -45,9 +45,9 @@ end # Test that outputs differ when using different activation functions @test !isapprox(y, y_custom, rtol=RTOL_HIGH) # with no initial state - test_gradients(cell_custom, g, g.x, loss=cell_loss, rtol=RTOL_HIGH) + test_gradients(cell_custom, g, g.x, 
loss=cell_loss, rtol=RTOL_HIGH, test_mooncake = TEST_MOONCAKE) # with initial state - test_gradients(cell_custom, g, g.x, h_custom, loss=cell_loss, rtol=RTOL_HIGH) + test_gradients(cell_custom, g, g.x, h_custom, loss=cell_loss, rtol=RTOL_HIGH, test_mooncake = TEST_MOONCAKE) end @testitem "TGCN" setup=[TemporalConvTestModule, TestModule] begin @@ -61,9 +61,9 @@ end @test layer isa GNNRecurrence @test size(y) == (out_channel, timesteps, g.num_nodes) # with no initial state - test_gradients(layer, g, x, rtol = RTOL_HIGH) + test_gradients(layer, g, x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE) # with initial state - test_gradients(layer, g, x, state0, rtol = RTOL_HIGH) + test_gradients(layer, g, x, state0, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE) # Test with custom activation function custom_activation = tanh @@ -74,15 +74,15 @@ end # Test that outputs differ when using different activation functions @test !isapprox(y, y_custom, rtol = RTOL_HIGH) # with no initial state - test_gradients(layer_custom, g, x, rtol = RTOL_HIGH) + test_gradients(layer_custom, g, x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE) # with initial state - test_gradients(layer_custom, g, x, state0, rtol = RTOL_HIGH) + test_gradients(layer_custom, g, x, state0, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE) # interplay with GNNChain model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1)) y = model(g, x) @test size(y) == (1, timesteps, g.num_nodes) - test_gradients(model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW) + test_gradients(model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW, test_mooncake = TEST_MOONCAKE) end @testitem "GConvLSTMCell" setup=[TemporalConvTestModule, TestModule] begin @@ -93,9 +93,9 @@ end @test size(h) == (out_channel, g.num_nodes) @test size(c) == (out_channel, g.num_nodes) # with no initial state - test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, 
atol=ATOL_LOW, test_mooncake = false) # with initial state - test_gradients(cell, g, g.x, (h, c), loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) + test_gradients(cell, g, g.x, (h, c), loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false) end @testitem "GConvLSTM" setup=[TemporalConvTestModule, TestModule] begin @@ -107,15 +107,15 @@ end y = layer(g, x) @test size(y) == (out_channel, timesteps, g.num_nodes) # with no initial state - test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) + test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false) # with initial state - test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW) + test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false) # interplay with GNNChain model = GNNChain(GConvLSTM(in_channel => out_channel, 2), Dense(out_channel, 1)) y = model(g, x) @test size(y) == (1, timesteps, g.num_nodes) - test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW) + test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW, test_mooncake = false) end @testitem "GConvGRUCell" setup=[TemporalConvTestModule, TestModule] begin @@ -125,9 +125,9 @@ end @test y === h @test size(h) == (out_channel, g.num_nodes) # with no initial state - test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false) # with initial state - test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) + test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false) end @@ -140,15 +140,15 @@ end y = layer(g, x) @test size(y) == (out_channel, timesteps, g.num_nodes) # with no initial state - test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) + test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false) # with initial state - test_gradients(layer, g, x, state0, rtol=RTOL_LOW, 
atol=ATOL_LOW) + test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false) # interplay with GNNChain model = GNNChain(GConvGRU(in_channel => out_channel, 2), Dense(out_channel, 1)) y = model(g, x) @test size(y) == (1, timesteps, g.num_nodes) - test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW) + test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW, test_mooncake = false) end @testitem "DCGRUCell" setup=[TemporalConvTestModule, TestModule] begin @@ -158,9 +158,9 @@ end @test y === h @test size(h) == (out_channel, g.num_nodes) # with no initial state - test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false) # with initial state - test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) + test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false) end @testitem "DCGRU" setup=[TemporalConvTestModule, TestModule] begin @@ -172,15 +172,15 @@ end y = layer(g, x) @test size(y) == (out_channel, timesteps, g.num_nodes) # with no initial state - test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) + test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false) # with initial state - test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW) + test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false) # interplay with GNNChain model = GNNChain(DCGRU(in_channel => out_channel, 2), Dense(out_channel, 1)) y = model(g, x) @test size(y) == (1, timesteps, g.num_nodes) - test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW) + test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW, test_mooncake = false) end @testitem "EvolveGCNOCell" setup=[TemporalConvTestModule, TestModule] begin @@ -189,9 +189,9 @@ end y, state = cell(g, g.x) @test size(y) == (out_channel, g.num_nodes) # with no 
initial state - test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false) # with initial state - test_gradients(cell, g, g.x, state, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) + test_gradients(cell, g, g.x, state, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false) end @testitem "EvolveGCNO" setup=[TemporalConvTestModule, TestModule] begin @@ -203,15 +203,15 @@ end y = layer(g, x) @test size(y) == (out_channel, timesteps, g.num_nodes) # with no initial state - test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) + test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE) # with initial state - test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW) + test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE) # interplay with GNNChain model = GNNChain(EvolveGCNO(in_channel => out_channel), Dense(out_channel, 1)) y = model(g, x) @test size(y) == (1, timesteps, g.num_nodes) - test_gradients(model, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) + test_gradients(model, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE) end # @testitem "GINConv" setup=[TemporalConvTestModule, TestModule] begin diff --git a/GraphNeuralNetworks/test/test_module.jl b/GraphNeuralNetworks/test/test_module.jl index 15fb7920c..450bebe5a 100644 --- a/GraphNeuralNetworks/test/test_module.jl +++ b/GraphNeuralNetworks/test/test_module.jl @@ -123,15 +123,11 @@ function test_gradients( end if test_mooncake - # Mooncake gradient with respect to input, compared against Zygote. + # Mooncake gradient with respect to input via Flux integration, compared against Zygote. loss_mc_x = (xs...) -> loss(f, graph, xs...) - # TODO error without `invokelatest` when using TestItemRunner - _cache_x = Base.invokelatest(Mooncake.prepare_gradient_cache, loss_mc_x, xs...) 
- y_mc, g_mc = Base.invokelatest(Mooncake.value_and_gradient!!, _cache_x, loss_mc_x, xs...) + y_mc, g_mc = Flux.withgradient(loss_mc_x, Flux.AutoMooncake(), xs...) @assert isapprox(y, y_mc; rtol, atol) - for i in eachindex(xs) - @assert isapprox(g[i], g_mc[i+1]; rtol, atol) - end + check_equal_leaves(g, g_mc; rtol, atol) end if test_gpu @@ -158,14 +154,10 @@ end if test_mooncake - # Mooncake gradient with respect to f, compared against Zygote. - ps_mc, re_mc = Flux.destructure(f) - loss_mc_f = ps -> loss(re_mc(ps), graph, xs...) - _cache_f = Base.invokelatest(Mooncake.prepare_gradient_cache, loss_mc_f, ps_mc) - y_mc, g_mc = Base.invokelatest(Mooncake.value_and_gradient!!, _cache_f, loss_mc_f, ps_mc) + # Mooncake gradient with respect to f via Flux integration, compared against Zygote. + y_mc, g_mc = Flux.withgradient(f -> loss(f, graph, xs...), Flux.AutoMooncake(), f) @assert isapprox(y, y_mc; rtol, atol) - g_mc_f = (re_mc(g_mc[2]),) - check_equal_leaves(g, g_mc_f; rtol, atol) + check_equal_leaves(g, g_mc; rtol, atol) end if test_gpu