Skip to content

Commit ff74af4

Browse files
implement hash for GNNGraph (#121)
* error message * == for graphs * implement hash * relax gatconv tests * fix test util
1 parent 2493acf commit ff74af4

File tree

5 files changed

+45
-7
lines changed

5 files changed

+45
-7
lines changed

src/GNNGraphs/gnngraph.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,4 +232,13 @@ Flux.Data._nobs(g::GNNGraph) = g.num_graphs
232232
Flux.Data._getobs(g::GNNGraph, i) = getgraph(g, i)
233233

234234
#########################
235-
Base.:(==)(g1::GNNGraph, g2::GNNGraph) = all(k -> getfield(g1,k)==getfield(g2,k), fieldnames(typeof(g1)))
235+
236+
function Base.:(==)(g1::GNNGraph, g2::GNNGraph)
237+
g1 === g2 && return true
238+
all(k -> getfield(g1, k) == getfield(g2, k), fieldnames(typeof(g1)))
239+
end
240+
241+
function Base.hash(g::T, h::UInt) where T<:GNNGraph
242+
fs = (getfield(g, k) for k in fieldnames(typeof(g)))
243+
return foldl((h, f) -> hash(f, h), fs, init=hash(T, h))
244+
end

src/GNNGraphs/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
5555
sz = map(x -> x isa AbstractArray ? size(x)[end] : 0, data)
5656
if duplicate_if_needed
5757
# Used to copy edge features on reverse edges
58-
@assert all(s -> s == 0 || s == n || s == n÷2, sz)
58+
@assert all(s -> s == 0 || s == n || s == n÷2, sz) "Wrong size in last dimension for feature array."
5959

6060
function duplicate(v)
6161
if v isa AbstractArray && size(v)[end] == n÷2
@@ -65,7 +65,7 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
6565
end
6666
data = NamedTuple{keys(data)}(duplicate.(values(data)))
6767
else
68-
@assert all(s -> s == 0 || s == n, sz)
68+
@assert all(s -> s == 0 || s == n, sz) "Wrong size in last dimension for feature array."
6969
end
7070
return data
7171
end

test/GNNGraphs/gnngraph.jl

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,9 +248,38 @@
248248
end
249249

250250
@testset "Graphs.jl integration" begin
251-
g = GNNGraph(erdos_renyi(10, 20))
251+
g = GNNGraph(erdos_renyi(10, 20), graph_type=GRAPH_T)
252252
@test g isa Graphs.AbstractGraph
253253
end
254+
255+
@testset "==" begin
256+
g1 = rand_graph(5, 6, ndata=rand(5), edata=rand(6), graph_type=GRAPH_T)
257+
@test g1 == g1
258+
@test g1 == deepcopy(g1)
259+
@test g1 !== deepcopy(g1)
260+
261+
g2 = GNNGraph(g1, graph_type=GRAPH_T)
262+
@test g1 == g2
263+
@test g1 === g2 # this is true since GNNGraph is immutable
264+
265+
g2 = GNNGraph(g1, ndata=rand(5), graph_type=GRAPH_T)
266+
@test g1 != g2
267+
@test g1 !== g2
268+
269+
g2 = GNNGraph(g1, edata=rand(6), graph_type=GRAPH_T)
270+
@test g1 != g2
271+
@test g1 !== g2
272+
end
273+
274+
@testset "hash" begin
275+
g1 = rand_graph(5, 6, ndata=rand(5), edata=rand(6), graph_type=GRAPH_T)
276+
@test hash(g1) == hash(g1)
277+
@test hash(g1) == hash(deepcopy(g1))
278+
@test hash(g1) == hash(GNNGraph(g1, ndata=g1.ndata, graph_type=GRAPH_T))
279+
@test hash(g1) == hash(GNNGraph(g1, ndata=g1.ndata, graph_type=GRAPH_T))
280+
@test hash(g1) != hash(GNNGraph(g1, ndata=rand(5), graph_type=GRAPH_T))
281+
@test hash(g1) != hash(GNNGraph(g1, edata=rand(6), graph_type=GRAPH_T))
282+
end
254283
end
255284

256285

test/layers/conv.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@
9898
for heads in (1, 2), concat in (true, false)
9999
l = GATConv(in_channel => out_channel; heads, concat)
100100
for g in test_graphs
101-
test_layer(l, g, rtol=1e-3, atol=1e-3,
101+
test_layer(l, g, rtol=1e-3,
102102
outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
103103
end
104104
end
@@ -114,7 +114,7 @@
114114
for heads in (1, 2), concat in (true, false)
115115
l = GATv2Conv(in_channel => out_channel; heads, concat)
116116
for g in test_graphs
117-
test_layer(l, g, rtol=1e-3, atol=1e-3,
117+
test_layer(l, g, rtol=1e-3,
118118
outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
119119
end
120120
end

test/test_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ function test_approx_structs(l, l̄, l̄2; atol=1e-5, rtol=1e-5,
199199
end
200200
else
201201
verbose && println("C")
202-
test_approx_structs(x, f̄, f̄2; exclude_grad_fields, broken_grad_fields, verbose)
202+
test_approx_structs(x, f̄, f̄2; atol, rtol, exclude_grad_fields, broken_grad_fields, verbose)
203203
end
204204
end
205205
return true

0 commit comments

Comments
 (0)