Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ JSON3 = "1"
LazyModules = "0.3"
MAT = "0.10, 0.11"
MLUtils = "0.2.0, 0.3, 0.4"
NPZ = "0.4.1"
NPZ = "0.4"
SparseArrays = "1.0"
Pickle = "0.3"
Requires = "1"
Statistics = "1"
Expand Down
2 changes: 2 additions & 0 deletions docs/src/datasets/graphs.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,6 @@ Reddit
TemporalBrains
TUDataset
WindMillEnergy
AmazonComputers
AmazonPhoto
```
5 changes: 4 additions & 1 deletion src/MLDatasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ export TemporalBrains
include("datasets/graphs/windmillenergy.jl")
export WindMillEnergy

include("datasets/graphs/amazon.jl")
export AmazonComputers
export AmazonPhoto
# Meshes

include("datasets/meshes/faust.jl")
Expand All @@ -168,7 +171,7 @@ function __init__()
__init__pemsbay()
__init__temporalbrains()
__init__windmillenergy()

__init__amazon()
# misc
__init__iris()
__init__mutagenesis()
Expand Down
174 changes: 174 additions & 0 deletions src/datasets/graphs/amazon.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
using DataDeps
using NPZ
using SparseArrays

function __init__amazon()

ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"
DEPNAME = "Amazon"
LINK = "https://github.com/shchur/gnn-benchmark/raw/master/data/npz"
DOCS = "https://github.com/shchur/gnn-benchmark"

DATA = [
"amazon_electronics_computers.npz",
"amazon_electronics_photo.npz"
]

register(DataDep(
DEPNAME,
"""
Dataset: Amazon Co-purchase Network
Website: $DOCS
""",
map(x -> "$LINK/$x", DATA),
"ceda5f611b7aa82c9e81eea9625526f2d7560e72e5c24dc010082915e52fb8b6"
))
end


function read_amazon_data(file; dir=nothing)

path = isnothing(dir) ? datadep"Amazon" : dir
full_path = joinpath(path, file)

vars = ["attr_data","attr_indices","attr_indptr","attr_shape",
"labels","adj_indices","adj_indptr","adj_shape"]

data = NPZ.npzread(full_path, vars)

attr_shape = data["attr_shape"]

x_sparse = SparseMatrixCSC(
attr_shape[2],
attr_shape[1],
data["attr_indptr"] .+ 1,
data["attr_indices"] .+ 1,
data["attr_data"]
)

x = Float32.(Matrix(x_sparse))

y = vec(Int.(data["labels"])) .+ 1

indices = data["adj_indices"]
indptr = data["adj_indptr"]

src = Int[]
dst = Int[]

for i in 1:(length(indptr)-1)
for j in (indptr[i]+1):indptr[i+1]

u = i
v = Int(indices[j]) + 1

push!(src, u)
push!(dst, v)

push!(src, v)
push!(dst, u)

end
end

edge_tuples = unique(zip(src, dst))
edge_tuples = filter(e -> e[1] != e[2], edge_tuples)

src = [e[1] for e in edge_tuples]
dst = [e[2] for e in edge_tuples]

num_nodes = attr_shape[1]

node_data = (
features = x,
targets = y
)

metadata = Dict(
"name" => file,
"num_classes" => length(unique(y))
)

g = Graph(
num_nodes=num_nodes,
edge_index=(src, dst),
node_data=node_data
)

return metadata, g

end


"""
AmazonComputers(; dir=nothing)

The Amazon Computers co-purchase network dataset from the
"Pitfalls of Graph Neural Network Evaluation" paper.

Nodes represent products and edges represent products frequently
bought together. Features are bag-of-words product reviews.

Statistics
- Nodes: 13752
- Edges: 491722
- Features: 767
"""
struct AmazonComputers <: AbstractDataset
metadata::Dict{String,Any}
graphs::Vector{Graph}
end


function AmazonComputers(; dir=nothing)

metadata, g = read_amazon_data(
"amazon_electronics_computers.npz",
dir=dir
)

AmazonComputers(metadata, [g])

end


Base.length(d::AmazonComputers) = length(d.graphs)
Base.getindex(d::AmazonComputers, ::Colon) = d.graphs[1]
Base.getindex(d::AmazonComputers, i) = d.graphs[i]


"""
AmazonPhoto(; dir=nothing)

The Amazon Photo co-purchase network dataset from the
"Pitfalls of Graph Neural Network Evaluation" paper.

Nodes represent photo-related products and edges represent
products frequently bought together.

Statistics
- Nodes: 7650
- Edges: 238162
- Features: 745
"""
struct AmazonPhoto <: AbstractDataset
metadata::Dict{String,Any}
graphs::Vector{Graph}
end


function AmazonPhoto(; dir=nothing)

metadata, g = read_amazon_data(
"amazon_electronics_photo.npz",
dir=dir
)

AmazonPhoto(metadata, [g])

end


Base.length(d::AmazonPhoto) = length(d.graphs)
Base.getindex(d::AmazonPhoto, ::Colon) = d.graphs[1]
Base.getindex(d::AmazonPhoto, i) = d.graphs[i]
51 changes: 51 additions & 0 deletions test/datasets/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,58 @@ end
@test maximum(a) == g.num_nodes
end
end
@testset "AmazonComputers" begin
data = AmazonComputers()
@test data isa AbstractDataset
@test length(data) == 1

g = data[1]
@test g === data[:]
@test g isa MLDatasets.Graph

@test g.num_nodes == 13752
@test g.num_edges == 491722

@test size(g.node_data.features) == (767, g.num_nodes)
@test size(g.node_data.targets) == (g.num_nodes,)

@test g.edge_index isa Tuple{Vector{Int}, Vector{Int}}

s, t = g.edge_index
for a in (s, t)
@test a isa Vector{Int}
@test length(a) == g.num_edges
@test minimum(a) == 1
@test maximum(a) == g.num_nodes
end
end


@testset "AmazonPhoto" begin
data = AmazonPhoto()
@test data isa AbstractDataset
@test length(data) == 1

g = data[1]
@test g === data[:]
@test g isa MLDatasets.Graph

@test g.num_nodes == 7650
@test g.num_edges == 238162

@test size(g.node_data.features) == (745, g.num_nodes)
@test size(g.node_data.targets) == (g.num_nodes,)

@test g.edge_index isa Tuple{Vector{Int}, Vector{Int}}

s, t = g.edge_index
for a in (s, t)
@test a isa Vector{Int}
@test length(a) == g.num_edges
@test minimum(a) == 1
@test maximum(a) == g.num_nodes
end
end
# maybe, maybe, maybe??
Sys.iswindows() || @testset "OGBn-mag" begin
data = OGBDataset("ogbn-mag")
Expand Down
Loading