@@ -1274,50 +1274,80 @@ def test_inequality(self):
1274
1274
1275
1275
def generate_symbols (self ):
1276
1276
model = Model ()
1277
- d , ds = model .disjoint_lists (10 , 4 )
1277
+ d = model .disjoint_lists_symbol (10 , 4 )
1278
1278
model .lock ()
1279
1279
yield d
1280
- yield from ds
1280
+ yield from d
1281
1281
1282
1282
def test (self ):
1283
1283
model = Model ()
1284
1284
1285
- model .disjoint_lists (10 , 4 )
1285
+ dls = model .disjoint_lists_symbol (10 , 4 )
1286
+
1287
+ self .assertEqual (dls .primary_set_size (), 10 )
1288
+ self .assertEqual (dls .num_disjoint_lists (), 4 )
1289
+
1290
+ def test_deprecated_creation_method (self ):
1291
+ model = Model ()
1292
+ with self .assertWarnsRegex (
1293
+ DeprecationWarning ,
1294
+ r"The return behavior of Model.disjoint_lists\(\) is deprecated"
1295
+ ):
1296
+ d , dls = model .disjoint_lists (10 , 4 )
1297
+
1298
+ self .assertIsInstance (d , dwave .optimization .symbols .DisjointLists )
1299
+ self .assertEqual (len (dls ), 4 )
1300
+ self .assertIsInstance (dls [0 ], dwave .optimization .symbols .DisjointList )
1301
+
1302
+ def test_indexing (self ):
1303
+ model = Model ()
1304
+
1305
+ dls = model .disjoint_lists_symbol (10 , 4 )
1306
+
1307
+ self .assertEqual (len (list (dls )), 4 )
1308
+ self .assertIsInstance (dls [0 ], dwave .optimization .symbols .DisjointList )
1309
+ self .assertIsInstance (dls [3 ], dwave .optimization .symbols .DisjointList )
1310
+
1311
+ with self .assertRaises (IndexError ):
1312
+ dls [4 ]
1286
1313
1287
1314
def test_construction (self ):
1288
1315
model = Model ()
1289
1316
1290
1317
with self .assertRaises (ValueError ):
1291
- model .disjoint_lists (- 5 , 1 )
1318
+ model .disjoint_lists_symbol (- 5 , 1 )
1292
1319
with self .assertRaises (ValueError ):
1293
- model .disjoint_lists (1 , - 5 )
1320
+ model .disjoint_lists_symbol (1 , - 5 )
1294
1321
1295
1322
model .states .resize (1 )
1296
1323
1297
- ds , ( x ,) = model .disjoint_lists (0 , 1 )
1298
- self .assertEqual (x .shape (), (- 1 ,)) # todo: handle this special case
1324
+ ds = model .disjoint_lists_symbol (0 , 1 )
1325
+ self .assertEqual (ds [ 0 ] .shape (), (- 1 ,)) # todo: handle this special case
1299
1326
1300
1327
def test_num_returned_nodes (self ):
1301
1328
model = Model ()
1302
1329
1303
- d , ds = model .disjoint_lists (10 , 4 )
1330
+ model .disjoint_lists_symbol (10 , 4 )
1331
+
1332
+ # One DisjointListsNode, and one node for each of the 4 successor lists
1333
+ self .assertEqual (model .num_nodes (), 5 )
1304
1334
1305
1335
def test_set_state (self ):
1306
1336
with self .subTest ("array-like output lists" ):
1307
1337
model = Model ()
1308
1338
model .states .resize (1 )
1309
- x , ys = model .disjoint_lists (5 , 3 )
1339
+ x = model .disjoint_lists_symbol (5 , 3 )
1310
1340
model .lock ()
1311
1341
1312
1342
x .set_state (0 , [[0 , 1 ], [2 , 3 ], [4 ]])
1313
1343
1314
- np .testing .assert_array_equal (ys [0 ].state (), [0 , 1 ])
1315
- np .testing .assert_array_equal (ys [1 ].state (), [2 , 3 ])
1316
- np .testing .assert_array_equal (ys [2 ].state (), [4 ])
1344
+ np .testing .assert_array_equal (x [0 ].state (), [0 , 1 ])
1345
+ np .testing .assert_array_equal (x [1 ].state (), [2 , 3 ])
1346
+ np .testing .assert_array_equal (x [2 ].state (), [4 ])
1317
1347
1318
1348
with self .subTest ("invalid state index" ):
1319
1349
model = Model ()
1320
- x , _ = model .disjoint_lists (5 , 3 )
1350
+ x = model .disjoint_lists_symbol (5 , 3 )
1321
1351
1322
1352
state = [[0 , 1 , 2 , 3 , 4 ], [], []]
1323
1353
@@ -1338,16 +1368,16 @@ def test_set_state(self):
1338
1368
# gets translated into integer according to NumPy rules
1339
1369
model = Model ()
1340
1370
model .states .resize (1 )
1341
- x , ys = model .disjoint_lists (5 , 3 )
1371
+ x = model .disjoint_lists_symbol (5 , 3 )
1342
1372
model .lock ()
1343
1373
1344
1374
x .set_state (0 , [[4.5 , 3 , 2 , 1 , 0 ], [], []])
1345
- np .testing .assert_array_equal (ys [0 ].state (), [4 , 3 , 2 , 1 , 0 ])
1375
+ np .testing .assert_array_equal (x [0 ].state (), [4 , 3 , 2 , 1 , 0 ])
1346
1376
1347
1377
with self .subTest ("invalid" ):
1348
1378
model = Model ()
1349
1379
model .states .resize (1 )
1350
- x , ys = model .disjoint_lists (5 , 3 )
1380
+ x = model .disjoint_lists_symbol (5 , 3 )
1351
1381
model .lock ()
1352
1382
1353
1383
with self .assertRaisesRegex (
@@ -1376,10 +1406,10 @@ def test_set_state(self):
1376
1406
def test_state_size (self ):
1377
1407
model = Model ()
1378
1408
1379
- d , ds = model .disjoint_lists (10 , 4 )
1409
+ d = model .disjoint_lists_symbol (10 , 4 )
1380
1410
1381
1411
self .assertEqual (d .state_size (), 0 )
1382
- for s in ds :
1412
+ for s in d :
1383
1413
self .assertEqual (s .state_size (), 10 * 8 )
1384
1414
1385
1415
0 commit comments