 
 import warnings
 
+from typing import List
+
 import numpy as np
 import theano.tensor as tt
 
@@ -565,3 +567,78 @@ def jacobian_det(self, y):
             else:
                 det += det_
         return det
+
+
+def _extend_axis(array, axis):
+    n = array.shape[axis] + 1
+    sum_vals = array.sum(axis, keepdims=True)
+    norm = sum_vals / (np.sqrt(n) + n)
+    fill_val = norm - sum_vals / np.sqrt(n)
+
+    out = tt.concatenate([array, fill_val], axis=axis)
+    return out - norm
+
+
+def _extend_axis_rev(array, axis):
+    if axis < 0:
+        axis = axis % array.ndim
+    assert axis >= 0 and axis < array.ndim
+
+    n = array.shape[axis]
+    last = tt.take(array, [-1], axis=axis)
+
+    sum_vals = -last * np.sqrt(n)
+    norm = sum_vals / (np.sqrt(n) + n)
+    slice_before = (slice(None, None),) * axis
+    return array[slice_before + (slice(None, -1),)] + norm
+
+
+def _extend_axis_val(array, axis):
+    n = array.shape[axis] + 1
+    sum_vals = array.sum(axis, keepdims=True)
+    norm = sum_vals / (np.sqrt(n) + n)
+    fill_val = norm - sum_vals / np.sqrt(n)
+
+    out = np.concatenate([array, fill_val], axis=axis)
+    return out - norm
+
+
+def _extend_axis_rev_val(array, axis):
+    n = array.shape[axis]
+    last = np.take(array, [-1], axis=axis)
+
+    sum_vals = -last * np.sqrt(n)
+    norm = sum_vals / (np.sqrt(n) + n)
+    slice_before = (slice(None, None),) * len(array.shape[:axis])
+    return array[slice_before + (slice(None, -1),)] + norm
+
+
+class ZeroSumTransform(Transform):
+    name = "zerosum"
+
+    _zerosum_axes: List[int]
+
+    def __init__(self, zerosum_axes):
+        self._zerosum_axes = zerosum_axes
+
+    def forward(self, x):
+        for axis in self._zerosum_axes:
+            x = _extend_axis_rev(x, axis=axis)
+        return floatX(x)
+
+    def forward_val(self, x, point):
+        for axis in self._zerosum_axes:
+            x = _extend_axis_rev_val(x, axis=axis)
+        return x
+
+    def backward(self, z):
+        z = tt.as_tensor_variable(z)
+        for axis in self._zerosum_axes:
+            z = _extend_axis(z, axis=axis)
+        return floatX(z)
+
+    def jacobian_det(self, x):
+        return tt.constant(0.0)
+
+
+zerosum = ZeroSumTransform
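A quick way to sanity-check the new transform is to exercise the pure-numpy helpers directly, since they mirror the theano ones. The following is a minimal sketch, assuming it runs in the same module so the private `_extend_axis_val` / `_extend_axis_rev_val` helpers are in scope; the shape and seed are arbitrary. It checks that extending a length-(n-1) unconstrained vector yields a length-n vector that sums to zero, and that the reverse map recovers the original values.

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=3)                    # unconstrained values, length n - 1

    y = _extend_axis_val(x, axis=0)           # extended values, length n
    assert y.shape == (4,)
    assert np.isclose(y.sum(), 0.0)           # zero-sum constraint holds

    x_back = _extend_axis_rev_val(y, axis=0)  # drop the redundant entry again
    assert np.allclose(x_back, x)             # round trip recovers the input

`ZeroSumTransform.forward` applies the reducing map once per axis in `_zerosum_axes`, so sampling happens on the lower-dimensional unconstrained space, while `backward` re-extends the draws so the model sees values that sum to zero along those axes. Returning `tt.constant(0.0)` from `jacobian_det` is consistent with the helpers implementing what appears to be an isometric (Householder-style) embedding, which preserves volume and therefore adds nothing to the log-density.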