40
40
41
41
from pytensor import tensor as pt
42
42
from pytensor .graph import RewriteDatabaseQuery
43
+ from pytensor .tensor .random .type import random_generator_type
43
44
from scipy import stats as st
44
45
45
46
from pymc .logprob .basic import conditional_logp , logp
@@ -352,7 +353,7 @@ def test_measurable_dimshuffle(ds_order, multivariate):
352
353
np .testing .assert_array_equal (ref_logp_fn (base_test_value ), ds_logp_fn (ds_test_value ))
353
354
354
355
355
- def test_unmeargeable_dimshuffles ():
356
+ def test_unmeasurable_dimshuffles ():
356
357
# Test that graphs with DimShuffles that cannot be lifted/merged fail
357
358
358
359
# Initial support axis is at axis=-1
@@ -372,3 +373,155 @@ def test_unmeargeable_dimshuffles():
372
373
# TODO: Check that logp is correct if this type of graphs is ever supported
373
374
with pytest .raises (RuntimeError , match = "could not be derived" ):
374
375
conditional_logp ({w : w_vv })
376
+
377
+
378
class TestMeasurableSplit:
    """Tests for the logp of the outputs of `pt.split` applied to a random variable."""

    def test_univariate(self):
        """Split a (6, 5) normal along each axis and compare each part's logp with scipy."""
        rng = np.random.default_rng(388)
        mu = np.arange(6)[:, None]
        sigma = np.arange(5) + 1

        x = pt.random.normal(mu, sigma, size=(6, 5), name="x")

        # axis=0
        x_parts = pt.split(x, splits_size=[2, 4], n_splits=2, axis=0)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        # Splitting along axis 0 partitions the rows of `mu`; `sigma` is shared by all rows
        np.testing.assert_allclose(
            logp_x1_eval,
            st.norm.logpdf(x_parts_test[0], mu[:2], sigma),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            st.norm.logpdf(x_parts_test[1], mu[2:], sigma),
        )

        # axis=1
        x_parts = pt.split(x, splits_size=[2, 1, 2], n_splits=3, axis=1)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
        logp_x1_eval, logp_x2_eval, logp_x3_eval = logp_fn(*x_parts_test)
        # Splitting along axis 1 partitions the entries of `sigma`; `mu` is shared by all columns
        np.testing.assert_allclose(
            logp_x1_eval,
            st.norm.logpdf(x_parts_test[0], mu, sigma[:2]),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            st.norm.logpdf(x_parts_test[1], mu, sigma[2:3]),
        )
        np.testing.assert_allclose(
            logp_x3_eval,
            st.norm.logpdf(x_parts_test[2], mu, sigma[3:]),
        )

    def test_multivariate(self):
        """Split a (3, 5) Dirichlet along its batch dimension and along its support dimension."""

        @np.vectorize(signature="(n),(n)->()")
        def scipy_dirichlet_logpdf(x, alpha):
            """Compute the logpdf of a Dirichlet distribution using scipy."""
            return st.dirichlet.logpdf(x, alpha)

        # (3, 5) Dirichlet
        rng = np.random.default_rng(426)
        rng_pt = random_generator_type("rng")
        alpha = np.linspace(1, 10, 5) * np.array([1, 10, 100])[:, None]
        x = pt.random.dirichlet(alpha, rng=rng_pt)

        # axis=-2 (i.e., 0, - batch dimension)
        x_parts = pt.split(x, splits_size=[2, 1], n_splits=2, axis=-2)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
        assert logp_parts[0].type.shape == (2,)
        assert logp_parts[1].type.shape == (1,)

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        # Test values are draws of the split outputs themselves
        x_parts_test = pytensor.function([rng_pt], x_parts)(rng)
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(
            logp_x1_eval,
            scipy_dirichlet_logpdf(x_parts_test[0], alpha[:2]),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            scipy_dirichlet_logpdf(x_parts_test[1], alpha[2:]),
        )

        # axis=-1 (i.e., 1, - support dimension)
        x_parts = pt.split(x, splits_size=[2, 3], n_splits=2, axis=-1)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

        assert logp_parts[0].type.shape == (3,)
        assert logp_parts[1].type.shape == (3,)
        logp_fn = pytensor.function(x_parts_vv, logp_parts)

        x_parts_test = pytensor.function([rng_pt], x_parts)(rng)
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        # The parts have sizes 2 and 3: check that logp_x1 / 2 == logp_x2 / 3
        np.testing.assert_allclose(logp_x1_eval * 3, logp_x2_eval * 2)
        # The logps of the parts must add up to the logp of the reassembled draw
        logp_total = logp_x1_eval + logp_x2_eval
        np.testing.assert_allclose(
            logp_total,
            scipy_dirichlet_logpdf(np.concatenate(x_parts_test, axis=1), alpha),
        )

    @pytest.mark.xfail(
        reason="Rewrite from partial split to split on subtensor not implemented yet"
    )
    def test_not_all_splits_used(self):
        """Request the logp of only the first and last of three splits of a normal vector."""
        x = pt.random.normal(mu=pt.arange(6), name="x")
        x_parts = pt.split(x, splits_size=[2, 2, 2], n_splits=3, axis=0)[
            ::2
        ]  # Only use the first and last splits (the middle one is skipped)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
        assert len(logp_parts) == 2

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        # The value variables are un-owned clones and cannot be evaluated on their own,
        # so the test values are drawn from the split outputs instead.
        x_parts_test = [x_part.eval() for x_part in x_parts]
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(
            logp_x1_eval,
            st.norm.logpdf(x_parts_test[0], loc=[0, 1]),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            st.norm.logpdf(x_parts_test[1], loc=[4, 5]),
        )

    def test_not_all_splits_used_core_dim(self):
        """Check that valuing only some splits of a Dirichlet raises a `ValueError`."""
        # TODO: We could support this for univariate/batch dimensions by rewriting as
        # split(x, splits_size=[2, 2, 2], n_splits=3, axis=1)[:2] -> split(x[:-2], splits_size=[2, 2], n_splits=2, axis=1)
        # And letting logp infer the probability of x[:-2]
        x = pt.random.dirichlet(alphas=pt.ones(6), name="x")
        x_parts = pt.split(x, splits_size=[2, 2, 2], n_splits=3, axis=0)[
            :2
        ]  # Only use first two splits
        x_parts_vv = [x_part.clone() for x_part in x_parts]

        with pytest.raises(
            ValueError,
            match="Split logp requires the number of values to match the number of splits",
        ):
            conditional_logp(dict(zip(x_parts, x_parts_vv)))

    @pytest.mark.xfail(reason="Rewrite from subtensor to split not implemented yet")
    def test_subtensor_converted_to_splits(self):
        """Request the logp of three contiguous slices that together cover a normal vector."""
        rng = np.random.default_rng(388)
        x = pt.random.normal(mu=pt.arange(5), name="x")

        x_parts = [x[:2], x[2:3], x[3:]]
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
        assert len(logp_parts) == 3
        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
        logp_x1_eval, logp_x2_eval, logp_x3_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(logp_x1_eval, st.norm.logpdf(x_parts_test[0], loc=[0, 1]))
        np.testing.assert_allclose(logp_x2_eval, st.norm.logpdf(x_parts_test[1], loc=[2]))
        np.testing.assert_allclose(logp_x3_eval, st.norm.logpdf(x_parts_test[2], loc=[3, 4]))
0 commit comments