Skip to content

Commit b11ea51

Browse files
authored
Merge pull request #53 from HERA-Team/ast-update
Updates for numpy 2
2 parents a1f9316 + e0a8831 commit b11ea51

File tree

3 files changed

+39
-8
lines changed

3 files changed

+39
-8
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
name: Warnings Tests
2+
on: [push, pull_request]
3+
4+
jobs:
5+
tests:
6+
name: Warning Tests
7+
runs-on: ubuntu-latest
8+
9+
steps:
10+
- uses: actions/checkout@v3
11+
- name: Setup Python
12+
uses: actions/setup-python@v4
13+
with:
14+
python-version: '3.12'
15+
16+
- name: Install linsolve
17+
run: |
18+
python -m pip install --upgrade pip
19+
pip install -e ".[dev]"
20+
21+
- name: Run Tests
22+
run: |
23+
pytest -W error

src/linsolve/linsolve.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def ast_getterms(n):
4545
if type(n) is ast.Name:
4646
return [[n.id]]
4747
elif type(n) is ast.Constant or type(n) is ast.Constant:
48-
return [[n.n]]
48+
return [[n.value]]
4949
elif type(n) is ast.Expression:
5050
return ast_getterms(n.body)
5151
elif type(n) is ast.UnaryOp:
@@ -564,14 +564,20 @@ def _invert_solve(self, A, y, rcond):
564564
methods.
565565
"""
566566
# As of numpy 1.8, solve works on stacks of matrices
567+
# Change in numpy 2.0:
568+
# The b array is only treated as a shape (M,) column vector if it is
569+
# exactly 1-dimensional. In all other instances it is treated as a stack
570+
# of (M, K) matrices. Previously b would be treated as a stack of (M,)
571+
# vectors if b.ndim was equal to a.ndim - 1.
567572
At = A.transpose([2, 1, 0]).conj()
568573
AtA = [np.dot(At[k], A[..., k]) for k in range(y.shape[-1])]
569-
Aty = [np.dot(At[k], y[..., k]) for k in range(y.shape[-1])]
574+
Aty = [np.dot(At[k], y[..., k])[:, None] for k in range(y.shape[-1])]
570575

571576
# This is slower by about 50%: scipy.linalg.solve(AtA, Aty, 'her')
572577

573578
# But this sometimes errors if singular:
574-
return np.linalg.solve(AtA, Aty).T
579+
print(len(AtA), len(Aty), AtA[0].shape, Aty[0].shape)
580+
return np.linalg.solve(AtA, Aty).T[0]
575581

576582
def _invert_solve_sparse(self, xs_ys_vals, y, rcond):
577583
"""Use linalg.solve to solve a fully constrained (non-degenerate) system of eqs.
@@ -690,7 +696,7 @@ def eval(self, sol, keys=None):
690696
def _chisq(self, sol, data, wgts, evaluator):
691697
"""Internal adaptable chisq calculator."""
692698
if len(wgts) == 0:
693-
sigma2 = {k: 1.0 for k in list(data.keys())} # equal weights
699+
sigma2 = dict.fromkeys(data.keys(), value=1.0) # equal weights
694700
else:
695701
sigma2 = {k: wgts[k] ** -1 for k in list(wgts.keys())}
696702
evaluated = evaluator(sol, keys=data)

tests/test_linsolve.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def test_eval(self):
141141

142142

143143
class TestLinearSolver:
144-
def setup(self):
144+
def setup_class(self):
145145
self.sparse = False
146146
eqs = ["x+y", "x-y"]
147147
x, y = 1, 2
@@ -366,7 +366,7 @@ def setup(self):
366366

367367

368368
class TestLogProductSolver:
369-
def setup(self):
369+
def setup_class(self):
370370
self.sparse = False
371371

372372
def test_init(self):
@@ -466,7 +466,7 @@ def setup(self):
466466

467467

468468
class TestLinProductSolver:
469-
def setup(self):
469+
def setup_class(self):
470470
self.sparse = False
471471

472472
def test_init(self):
@@ -490,7 +490,9 @@ def test_init(self):
490490
np.testing.assert_almost_equal(eval(k), 0.002)
491491
assert len(ls.ls.prms) == 3
492492

493-
ls = linsolve.LinProductSolver(d, sol0, w, sparse=self.sparse, build_solver=False)
493+
ls = linsolve.LinProductSolver(
494+
d, sol0, w, sparse=self.sparse, build_solver=False
495+
)
494496
assert not hasattr(ls, "ls")
495497
assert ls.dtype == np.complex64
496498

0 commit comments

Comments
 (0)