Skip to content
6 changes: 6 additions & 0 deletions _nx_cugraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
"descendants",
"descendants_at_distance",
"diamond_graph",
"digraph__new__",
"dijkstra_path",
"dijkstra_path_length",
"dodecahedral_graph",
Expand All @@ -98,6 +99,7 @@
"from_scipy_sparse_array",
"frucht_graph",
"generic_bfs_edges",
"graph__new__",
"has_path",
"heawood_graph",
"hits",
Expand Down Expand Up @@ -126,6 +128,8 @@
"louvain_communities",
"lowest_common_ancestor",
"moebius_kantor_graph",
"multidigraph__new__",
"multigraph__new__",
"node_connected_component",
"null_graph",
"number_connected_components",
Expand Down Expand Up @@ -360,6 +364,8 @@ def update_env_var(varname):
update_env_var("NETWORKX_AUTOMATIC_BACKENDS") # For NetworkX 3.2
# Automatically create nx-cugraph Graph from graph generators
update_env_var("NETWORKX_BACKEND_PRIORITY_GENERATORS")
# And for graph classes such as `nx.Graph()` for NetworkX >=3.6
update_env_var("NETWORKX_BACKEND_PRIORITY_CLASSES")
# Run default NetworkX implementation (in >=3.4) if not implemented by nx-cugraph
if (varname := "NETWORKX_FALLBACK_TO_NX") not in os.environ:
os.environ[varname] = "true"
Expand Down
16 changes: 15 additions & 1 deletion nx_cugraph/classes/digraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import nx_cugraph as nxcg

from ..utils import index_dtype
from ..utils import index_dtype, networkx_algorithm
from .graph import CudaGraph, Graph, _GraphCache

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -310,3 +310,17 @@ def _out_degrees_array(self, *, ignore_selfloops=False):
if src_indices.size == 0:
return cp.zeros(self._N, dtype=np.int64)
return cp.bincount(src_indices, minlength=self._N)


@networkx_algorithm(version_added="25.12")
def digraph__new__(cls, incoming_graph_data=None, **attr):
if nx.config.backends.cugraph.use_compat_graphs:
return object.__new__(DiGraph)
return CudaDiGraph(incoming_graph_data=incoming_graph_data, **attr)


@digraph__new__._can_run
def _(cls, incoming_graph_data=None, **attr):
if cls is not nx.DiGraph:
return "Unknown subclasses of nx.DiGraph are not supported."
return True
20 changes: 19 additions & 1 deletion nx_cugraph/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import nx_cugraph as nxcg
from nx_cugraph import _nxver

from ..utils import index_dtype
from ..utils import index_dtype, networkx_algorithm

if TYPE_CHECKING: # pragma: no cover
from collections.abc import Iterable, Iterator
Expand Down Expand Up @@ -1302,3 +1302,21 @@ def _subgraph_weights(
edge_array = edge_array[mask]

return edge_array


@networkx_algorithm(version_added="25.12")
def graph__new__(cls, incoming_graph_data=None, **attr):
# Dispatched from `nx.Graph.__new__`. See details of `object.__new__` behavior here:
# https://docs.python.org/3/reference/datamodel.html#object.__new__
if nx.config.backends.cugraph.use_compat_graphs:
# Because `issubclass(Graph, nx.Graph)`, Graph.__init__ will be called next
return object.__new__(Graph)
# Because `not issubclass(CudaGraph, nx.Graph)`, CudaGraph.__init__ WON'T be called
return CudaGraph(incoming_graph_data=incoming_graph_data, **attr)


@graph__new__._can_run
def _(cls, incoming_graph_data=None, **attr):
if cls is not nx.Graph:
return "Unknown subclasses of nx.Graph are not supported."
return True
19 changes: 19 additions & 0 deletions nx_cugraph/classes/multidigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import nx_cugraph as nxcg

from ..utils import networkx_algorithm
from .digraph import CudaDiGraph, DiGraph
from .graph import Graph, _GraphCache
from .multigraph import CudaMultiGraph, MultiGraph
Expand Down Expand Up @@ -93,3 +94,21 @@ def _to_compat_graph_class(cls) -> type[MultiDiGraph]:
@networkx_api
def to_undirected(self, reciprocal=False, as_view=False):
raise NotImplementedError


@networkx_algorithm(version_added="25.12")
def multidigraph__new__(cls, incoming_graph_data=None, multigraph_input=None, **attr):
if nx.config.backends.cugraph.use_compat_graphs:
return object.__new__(MultiDiGraph)
return CudaMultiDiGraph(
incoming_graph_data=incoming_graph_data,
multigraph_input=multigraph_input,
**attr,
)


@multidigraph__new__._can_run
def _(cls, incoming_graph_data=None, multigraph_input=None, **attr):
if cls is not nx.MultiDiGraph:
return "Unknown subclasses of nx.MultiDiGraph are not supported."
return True
22 changes: 20 additions & 2 deletions nx_cugraph/classes/multigraph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Copyright (c) 2023-2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -21,7 +21,7 @@

import nx_cugraph as nxcg

from ..utils import index_dtype
from ..utils import index_dtype, networkx_algorithm
from .graph import CudaGraph, Graph, _GraphCache

if TYPE_CHECKING:
Expand Down Expand Up @@ -580,3 +580,21 @@ def _sort_edge_indices(self, primary="src"):
if self.edge_keys is not None:
edge_keys = self.edge_keys
self.edge_keys = [edge_keys[i] for i in indices.tolist()]


@networkx_algorithm(version_added="25.12")
def multigraph__new__(cls, incoming_graph_data=None, multigraph_input=None, **attr):
if nx.config.backends.cugraph.use_compat_graphs:
return object.__new__(MultiGraph)
return CudaMultiGraph(
incoming_graph_data=incoming_graph_data,
multigraph_input=multigraph_input,
**attr,
)


@multigraph__new__._can_run
def _(cls, incoming_graph_data=None, multigraph_input=None, **attr):
if cls is not nx.MultiGraph:
return "Unknown subclasses of nx.MultiGraph are not supported."
return True
5 changes: 3 additions & 2 deletions nx_cugraph/scripts/print_tree.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -143,7 +143,8 @@ def create_tree(
incomplete=incomplete,
different=different,
)
assoc_in(tree, path.split("."), payload)
if payload is not None:
assoc_in(tree, path.split("."), payload)
return tree


Expand Down
61 changes: 61 additions & 0 deletions nx_cugraph/tests/test_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import networkx as nx
import pytest

import nx_cugraph as nxcg
from nx_cugraph import _nxver
from nx_cugraph.classes.graph import _GraphCache


Expand Down Expand Up @@ -80,6 +82,65 @@ def test_class_to_class():
assert cls.is_multigraph() == G.is_multigraph() == val.is_multigraph()


@pytest.mark.parametrize(
"nxcg_class", [nxcg.Graph, nxcg.DiGraph, nxcg.MultiGraph, nxcg.MultiDiGraph]
)
@pytest.mark.parametrize("use_compat_graphs", [True, False])
def test_dispatch_graph_classes(nxcg_class, use_compat_graphs):
if _nxver < (3, 6):
pytest.skip(reason="Dispatching graph classes requires nx >=3.6")
nx_class = nxcg_class.to_networkx_class()
assert nx_class is not nxcg_class
cuda_class = nxcg_class.to_cudagraph_class()
assert cuda_class is not nxcg_class

expected_nxcg_class = nxcg_class if use_compat_graphs else cuda_class

class NxGraphSubclass(nx_class):
pass

class NxcgGraphSubclass(nxcg_class):
pass

with (
nx.config.backend_priority(classes=[]),
nx.config.backends.cugraph(use_compat_graphs=use_compat_graphs),
):
G = nx_class()
assert type(G) is nx_class
G = nx_class(backend="cugraph")
assert type(G) is expected_nxcg_class
G = NxGraphSubclass()
assert type(G) is NxGraphSubclass
with pytest.raises(NotImplementedError, match="not implemented by 'cugraph'"):
# can_run is False for unknown subclasses
NxGraphSubclass(backend="cugraph")
G = NxcgGraphSubclass()
assert type(G) is NxcgGraphSubclass
with pytest.raises(NotImplementedError, match="not implemented by 'cugraph'"):
NxcgGraphSubclass(backend="cugraph")

with (
nx.config.backend_priority(classes=["cugraph"]),
nx.config.backends.cugraph(use_compat_graphs=use_compat_graphs),
):
G = nx_class()
assert type(G) is expected_nxcg_class
G = nx_class(backend="networkx")
assert type(G) is nx_class

# can_run is False for unknown subclasses
G = NxGraphSubclass()
assert type(G) is NxGraphSubclass
G = NxGraphSubclass(backend="networkx")
assert type(G) is NxGraphSubclass

G = NxcgGraphSubclass()
assert type(G) is NxcgGraphSubclass
G = NxcgGraphSubclass(backend="networkx")
assert type(G) is NxcgGraphSubclass # Perhaps odd, but the correct behavior


@pytest.mark.parametrize(
"graph_class", [nxcg.Graph, nxcg.DiGraph, nxcg.MultiGraph, nxcg.MultiDiGraph]
)
Expand Down
3 changes: 3 additions & 0 deletions scripts/update_readme.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def main(readme_file, objects_filename):
def get_payload(info, **kwargs):
path = "networkx." + info.networkx_path
subpath, name = path.rsplit(".", 1)
if "__" in name:
# Don't include e.g. Graph.__new__
return None
# Many objects are referred to in modules above where they are defined.
while True:
path = f"{subpath}.{name}"
Expand Down
Loading