Skip to content

Commit 7534f09

Browse files
authored
Merge pull request #397 from wbernoudy/feature/improve-disjoint-lists-interface
QOL improvements for DisjointLists
2 parents 9eb2055 + 89d750d commit 7534f09

File tree

7 files changed

+149
-45
lines changed

7 files changed

+149
-45
lines changed

dwave/optimization/_model.pyx

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,10 +1144,9 @@ cdef class Symbol:
11441144
11451145
>>> from dwave.optimization.model import Model
11461146
>>> model = Model()
1147-
>>> lsymbol, lsymbol_lists = model.disjoint_lists(
1148-
... primary_set_size=5,
1149-
... num_disjoint_lists=2)
1150-
>>> lsymbol_lists[0].equals(next(lsymbol.iter_successors()))
1147+
>>> x = model.binary()
1148+
>>> y = x + 5
1149+
>>> y.equals(next(x.iter_successors()))
11511150
True
11521151
"""
11531152
cdef vector[cppNode.SuccessorView].const_iterator it = self.node_ptr.successors().begin()
@@ -1233,17 +1232,17 @@ cdef class Symbol:
12331232
12341233
>>> from dwave.optimization import Model
12351234
>>> model = Model()
1236-
>>> lsymbol, lsymbol_lists = model.disjoint_lists(primary_set_size=5, num_disjoint_lists=2)
1235+
>>> lsymbol = model.disjoint_lists_symbol(primary_set_size=5, num_disjoint_lists=2)
12371236
>>> with model.lock():
12381237
... model.states.resize(2)
12391238
... lsymbol.set_state(0, [[0, 4], [1, 2, 3]])
12401239
... lsymbol.set_state(1, [[3, 4], [0, 1, 2]])
1241-
... print(f"state 0: {lsymbol_lists[0].state(0)} and {lsymbol_lists[1].state(0)}")
1242-
... print(f"state 1: {lsymbol_lists[0].state(1)} and {lsymbol_lists[1].state(1)}")
1240+
... print(f"state 0: {lsymbol[0].state(0)} and {lsymbol[1].state(0)}")
1241+
... print(f"state 1: {lsymbol[0].state(1)} and {lsymbol[1].state(1)}")
12431242
... lsymbol.reset_state(0)
12441243
... print("After reset:")
1245-
... print(f"state 0: {lsymbol_lists[0].state(0)} and {lsymbol_lists[1].state(0)}")
1246-
... print(f"state 1: {lsymbol_lists[0].state(1)} and {lsymbol_lists[1].state(1)}")
1244+
... print(f"state 0: {lsymbol[0].state(0)} and {lsymbol[1].state(0)}")
1245+
... print(f"state 1: {lsymbol[0].state(1)} and {lsymbol[1].state(1)}")
12471246
state 0: [0. 4.] and [1. 2. 3.]
12481247
state 1: [3. 4.] and [0. 1. 2.]
12491248
After reset:

dwave/optimization/generators.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def capacitated_vehicle_routing(demand: numpy.typing.ArrayLike,
267267
A model encoding the CVRP problem.
268268
269269
Notes:
270-
The model uses a :class:`~dwave.optimization.model.Model.disjoint_lists`
270+
The model uses a :class:`~dwave.optimization.symbols.DisjointLists`
271271
class as the decision variable being optimized, with permutations of its
272272
sublist representing various itineraries for each vehicle.
273273
"""
@@ -381,7 +381,7 @@ class as the decision variable being optimized, with permutations of its
381381
capacity = model.constant(vehicle_capacity)
382382

383383
# Add the decision variable
384-
routes_decision, routes = model.disjoint_lists(
384+
routes = model.disjoint_lists_symbol(
385385
primary_set_size=num_customers,
386386
num_disjoint_lists=number_of_vehicles)
387387

@@ -455,7 +455,7 @@ def capacitated_vehicle_routing_with_time_windows(demand: numpy.typing.ArrayLike
455455
A model encoding the CVRPTW problem.
456456
457457
Notes:
458-
The model uses a :class:`~dwave.optimization.model.Model.disjoint_lists`
458+
The model uses a :class:`~dwave.optimization.symbols.DisjointLists`
459459
class as the decision variable being optimized, with permutations of its
460460
sublist representing various itineraries for each vehicle.
461461
"""
@@ -561,7 +561,7 @@ class as the decision variable being optimized, with permutations of its
561561
one = model.constant(1)
562562

563563
# Add the decision variable
564-
routes_decision, routes = model.disjoint_lists(
564+
routes = model.disjoint_lists_symbol(
565565
primary_set_size=num_customers,
566566
num_disjoint_lists=number_of_vehicles)
567567

dwave/optimization/model.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import numpy as np
3333
import tempfile
3434
import typing
35+
import warnings
3536

3637
from dwave.optimization._model import ArraySymbol, _Graph, Symbol
3738
from dwave.optimization.states import States
@@ -331,11 +332,65 @@ def disjoint_lists(
331332
>>> from dwave.optimization.model import Model
332333
>>> model = Model()
333334
>>> destinations, routes = model.disjoint_lists(10, 4)
335+
336+
.. deprecated:: 0.6.7
337+
338+
The return behavior of this method will be changed in
339+
dwave.optimization 0.8.0. Use :meth:`.disjoint_lists_symbol`.
340+
"""
341+
342+
warnings.warn(
343+
"The return behavior of Model.disjoint_lists() is deprecated "
344+
"since dwave.optimization 0.6.7 and will be changed to the "
345+
"behavior of Model.disjoint_lists_symbol() in 0.8.0. Use "
346+
"Model.disjoint_lists_symbol().",
347+
DeprecationWarning,
348+
)
349+
350+
disjoint_lists = self.disjoint_lists_symbol(
351+
primary_set_size, num_disjoint_lists
352+
)
353+
return disjoint_lists, list(disjoint_lists)
354+
355+
def disjoint_lists_symbol(
356+
self,
357+
primary_set_size: int,
358+
num_disjoint_lists: int,
359+
) -> DisjointLists:
360+
"""Create a disjoint-lists symbol as a decision variable.
361+
362+
Divides a set of the elements of ``range(primary_set_size)`` into
363+
``num_disjoint_lists`` ordered partitions.
364+
365+
Args:
366+
primary_set_size: Number of elements in the primary set to
367+
be partitioned into disjoint lists.
368+
num_disjoint_lists: Number of disjoint lists.
369+
370+
Returns:
371+
A disjoint-lists symbol.
372+
373+
Examples:
374+
This example creates a symbol of 10 elements that is divided
375+
into 4 lists.
376+
377+
>>> from dwave.optimization.model import Model
378+
>>> model = Model()
379+
>>> disjoint_lists = model.disjoint_lists_symbol(10, 4)
380+
>>> disjoint_lists.primary_set_size()
381+
10
382+
>>> disjoint_lists.num_disjoint_lists()
383+
4
334384
"""
335385
from dwave.optimization.symbols import DisjointLists, DisjointList # avoid circular import
336-
main = DisjointLists(self, primary_set_size, num_disjoint_lists)
337-
lists = [DisjointList(main, i) for i in range(num_disjoint_lists)]
338-
return main, lists
386+
disjoint_lists = DisjointLists(self, primary_set_size, num_disjoint_lists)
387+
388+
# create the DisjointList symbols, which will create the successor nodes, even
389+
# though we won't use them directly here
390+
for i in range(num_disjoint_lists):
391+
DisjointList(disjoint_lists, i)
392+
393+
return disjoint_lists
339394

340395
def feasible(self, index: int = 0) -> bool:
341396
"""Check the feasibility of the state at the input index.

dwave/optimization/symbols/collections.pyx

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ cdef class DisjointList(ArraySymbol):
305305
"""
306306
def __init__(self, DisjointLists parent, Py_ssize_t list_index):
307307
if list_index < 0 or list_index >= parent.num_disjoint_lists():
308-
raise ValueError(
308+
raise IndexError(
309309
"`list_index` must be less than the number of disjoint sets of the parent"
310310
)
311311

@@ -411,6 +411,9 @@ cdef class DisjointLists(Symbol):
411411

412412
self.initialize_node(model, self.ptr)
413413

414+
def __getitem__(self, index: int):
415+
return DisjointList(self, index)
416+
414417
@classmethod
415418
def _from_symbol(cls, Symbol symbol):
416419
cdef DisjointListsNode* ptr = dynamic_cast_ptr[DisjointListsNode](symbol.node_ptr)
@@ -491,21 +494,21 @@ cdef class DisjointLists(Symbol):
491494
Index of the state to set
492495
state:
493496
Assignment of values for the state.
494-
497+
495498
Examples:
496499
This example sets the state of a disjoint-lists symbol. You can
497500
inspect the state of each list individually.
498-
501+
499502
>>> from dwave.optimization.model import Model
500503
>>> model = Model()
501-
>>> lists_symbol, lists_array = model.disjoint_lists(
504+
>>> lists_symbol = model.disjoint_lists_symbol(
502505
... primary_set_size=5,
503506
... num_disjoint_lists=3
504507
... )
505508
>>> with model.lock():
506509
... model.states.resize(1)
507510
... lists_symbol.set_state(0, [[0, 1, 2, 3], [4], []])
508-
... for index, disjoint_list in enumerate(lists_array):
511+
... for index, disjoint_list in enumerate(lists_symbol):
509512
... print(f"DisjointList {index}:")
510513
... print(disjoint_list.state(0))
511514
DisjointList 0:
@@ -569,6 +572,10 @@ cdef class DisjointLists(Symbol):
569572
"""Return the number of disjoint lists in the symbol."""
570573
return self.ptr.num_disjoint_lists()
571574

575+
def primary_set_size(self):
576+
"""Return the size of primary set of elements that the lists contain."""
577+
return self.ptr.primary_set_size()
578+
572579
# An observing pointer to the C++ DisjointListsNode
573580
cdef DisjointListsNode* ptr
574581

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
---
2+
features:
3+
- |
4+
``DisjointLists`` symbols are now indexable, returning the corresponding
5+
``DisjointList`` (singular) symbol for the given index. This allows the
6+
method of creating disjoint lists on the model to be simplified, returning
7+
only a ``DisjointLists`` symbol instead of a tuple of the symbol and a
8+
list of ``DisjointList`` symbols. A new method
9+
``Model.disjoint_lists_symbol()`` has been added to the Model class which
10+
implements this.
11+
deprecations:
12+
- |
13+
The ``Model.disjoint_lists()`` method has been deprecated. Use
14+
``Model.disjoint_lists_symbol()`` instead.

tests/test_model.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -433,19 +433,18 @@ def test_remove_unused_symbols(self):
433433
with self.subTest("disjoint lists"):
434434
model = Model()
435435

436-
base, lists = model.disjoint_lists(10, 4)
436+
disjoint_lists = model.disjoint_lists_symbol(10, 4)
437437

438438
# only use some of the lists
439-
model.minimize(lists[0].sum())
440-
model.add_constraint(lists[1].sum() <= model.constant(3))
439+
model.minimize(disjoint_lists[0].sum())
440+
model.add_constraint(disjoint_lists[1].sum() <= model.constant(3))
441441

442-
lists[2].prod() # this one will hopefully be removed
442+
disjoint_lists[2].prod() # this one will hopefully be removed
443443

444444
self.assertEqual(model.num_symbols(), 10)
445445

446446
# make sure they aren't being kept alive by other objects
447-
del lists
448-
del base
447+
del disjoint_lists
449448

450449
num_removed = model.remove_unused_symbols()
451450

tests/test_symbols.py

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,50 +1274,80 @@ def test_inequality(self):
12741274

12751275
def generate_symbols(self):
12761276
model = Model()
1277-
d, ds = model.disjoint_lists(10, 4)
1277+
d = model.disjoint_lists_symbol(10, 4)
12781278
model.lock()
12791279
yield d
1280-
yield from ds
1280+
yield from d
12811281

12821282
def test(self):
12831283
model = Model()
12841284

1285-
model.disjoint_lists(10, 4)
1285+
dls = model.disjoint_lists_symbol(10, 4)
1286+
1287+
self.assertEqual(dls.primary_set_size(), 10)
1288+
self.assertEqual(dls.num_disjoint_lists(), 4)
1289+
1290+
def test_deprecated_creation_method(self):
1291+
model = Model()
1292+
with self.assertWarnsRegex(
1293+
DeprecationWarning,
1294+
r"The return behavior of Model.disjoint_lists\(\) is deprecated"
1295+
):
1296+
d, dls = model.disjoint_lists(10, 4)
1297+
1298+
self.assertIsInstance(d, dwave.optimization.symbols.DisjointLists)
1299+
self.assertEqual(len(dls), 4)
1300+
self.assertIsInstance(dls[0], dwave.optimization.symbols.DisjointList)
1301+
1302+
def test_indexing(self):
1303+
model = Model()
1304+
1305+
dls = model.disjoint_lists_symbol(10, 4)
1306+
1307+
self.assertEqual(len(list(dls)), 4)
1308+
self.assertIsInstance(dls[0], dwave.optimization.symbols.DisjointList)
1309+
self.assertIsInstance(dls[3], dwave.optimization.symbols.DisjointList)
1310+
1311+
with self.assertRaises(IndexError):
1312+
dls[4]
12861313

12871314
def test_construction(self):
12881315
model = Model()
12891316

12901317
with self.assertRaises(ValueError):
1291-
model.disjoint_lists(-5, 1)
1318+
model.disjoint_lists_symbol(-5, 1)
12921319
with self.assertRaises(ValueError):
1293-
model.disjoint_lists(1, -5)
1320+
model.disjoint_lists_symbol(1, -5)
12941321

12951322
model.states.resize(1)
12961323

1297-
ds, (x,) = model.disjoint_lists(0, 1)
1298-
self.assertEqual(x.shape(), (-1,)) # todo: handle this special case
1324+
ds = model.disjoint_lists_symbol(0, 1)
1325+
self.assertEqual(ds[0].shape(), (-1,)) # todo: handle this special case
12991326

13001327
def test_num_returned_nodes(self):
13011328
model = Model()
13021329

1303-
d, ds = model.disjoint_lists(10, 4)
1330+
model.disjoint_lists_symbol(10, 4)
1331+
1332+
# One DisjointListsNode, and one node for each of the 4 successor lists
1333+
self.assertEqual(model.num_nodes(), 5)
13041334

13051335
def test_set_state(self):
13061336
with self.subTest("array-like output lists"):
13071337
model = Model()
13081338
model.states.resize(1)
1309-
x, ys = model.disjoint_lists(5, 3)
1339+
x = model.disjoint_lists_symbol(5, 3)
13101340
model.lock()
13111341

13121342
x.set_state(0, [[0, 1], [2, 3], [4]])
13131343

1314-
np.testing.assert_array_equal(ys[0].state(), [0, 1])
1315-
np.testing.assert_array_equal(ys[1].state(), [2, 3])
1316-
np.testing.assert_array_equal(ys[2].state(), [4])
1344+
np.testing.assert_array_equal(x[0].state(), [0, 1])
1345+
np.testing.assert_array_equal(x[1].state(), [2, 3])
1346+
np.testing.assert_array_equal(x[2].state(), [4])
13171347

13181348
with self.subTest("invalid state index"):
13191349
model = Model()
1320-
x, _ = model.disjoint_lists(5, 3)
1350+
x = model.disjoint_lists_symbol(5, 3)
13211351

13221352
state = [[0, 1, 2, 3, 4], [], []]
13231353

@@ -1338,16 +1368,16 @@ def test_set_state(self):
13381368
# gets translated into integer according to NumPy rules
13391369
model = Model()
13401370
model.states.resize(1)
1341-
x, ys = model.disjoint_lists(5, 3)
1371+
x = model.disjoint_lists_symbol(5, 3)
13421372
model.lock()
13431373

13441374
x.set_state(0, [[4.5, 3, 2, 1, 0], [], []])
1345-
np.testing.assert_array_equal(ys[0].state(), [4, 3, 2, 1, 0])
1375+
np.testing.assert_array_equal(x[0].state(), [4, 3, 2, 1, 0])
13461376

13471377
with self.subTest("invalid"):
13481378
model = Model()
13491379
model.states.resize(1)
1350-
x, ys = model.disjoint_lists(5, 3)
1380+
x = model.disjoint_lists_symbol(5, 3)
13511381
model.lock()
13521382

13531383
with self.assertRaisesRegex(
@@ -1376,10 +1406,10 @@ def test_set_state(self):
13761406
def test_state_size(self):
13771407
model = Model()
13781408

1379-
d, ds = model.disjoint_lists(10, 4)
1409+
d = model.disjoint_lists_symbol(10, 4)
13801410

13811411
self.assertEqual(d.state_size(), 0)
1382-
for s in ds:
1412+
for s in d:
13831413
self.assertEqual(s.state_size(), 10 * 8)
13841414

13851415

0 commit comments

Comments
 (0)