-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathtests.py
More file actions
724 lines (553 loc) · 23.2 KB
/
tests.py
File metadata and controls
724 lines (553 loc) · 23.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
import sys
from functools import wraps
import jax
import jax.core
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
import pytest
from jax import lax
import klujax
from klujax import COMPLEX_DTYPES, coalesce
# Lookup table pairing each sparse klujax operation under test with the dense
# JAX operation the tests use as its reference implementation.
OPS_DENSE = { # sparse to dense
klujax.dot: lax.dot,
klujax.solve: jsp.linalg.solve,
}
def log_test_name(f):
    """Decorate ``f`` so every call first announces itself on stderr.

    The wrapper writes a blank line followed by the function name, then the
    positional arguments (only when there are any), then the keyword
    arguments (only when there are any), and finally returns whatever ``f``
    returns. ``functools.wraps`` copies the metadata of ``f`` onto the wrapper.
    """

    @wraps(f)
    def announce_then_call(*args, **kwargs):
        # Assemble the same lines the log always contains and emit them in one
        # write; the leading "\n" separates consecutive test logs.
        report = [f"\n{f.__name__}"]
        if args:
            report.append(f"args={args}")
        if kwargs:
            report.append(f"kwargs={kwargs}")
        print("\n".join(report), file=sys.stderr)
        return f(*args, **kwargs)

    return announce_then_call
def parametrize_dtypes(func):
    """Parametrize ``func`` over a ``dtype`` argument: float64 and complex128."""
    dtype_cases = [np.float64, np.complex128]
    mark = pytest.mark.parametrize("dtype", dtype_cases)
    return mark(func)
def parametrize_ops(func):
    """Parametrize ``func`` over ``op_sparse``: ``klujax.solve`` and ``klujax.dot``."""
    sparse_ops = [klujax.solve, klujax.dot]
    mark = pytest.mark.parametrize("op_sparse", sparse_ops)
    return mark(func)
@log_test_name
@parametrize_dtypes
@parametrize_ops
def test_1d(dtype, op_sparse):
    """Compare a sparse op with its dense reference for one matrix and a 1D ``b``.

    The dense matrix is rebuilt by scatter-adding ``Ax`` at ``(Ai, Aj)`` into
    an ``n_col x n_col`` array of zeros.
    """
    n_col = 5
    Ai, Aj, Ax, b = _get_rand_arrs_1d(15, n_col, dtype=dtype)

    # Dense reference.
    dense_matrix = jnp.zeros((n_col, n_col), dtype=Ax.dtype).at[Ai, Aj].add(Ax)
    expected = OPS_DENSE[op_sparse](dense_matrix, b)

    # Sparse result under test.
    actual = op_sparse(Ai, Aj, Ax, b)

    _log_and_test_equality(expected, actual)
@log_test_name
@parametrize_dtypes
@parametrize_ops
def test_2d(dtype, op_sparse):
    """Compare a sparse op with its dense reference on batched (2D) inputs.

    The sparse op is called directly on the arrays from ``_get_rand_arrs_2d``;
    the dense reference is mapped with ``jax.vmap`` over the leading axis of
    the dense matrices and of ``b``.
    """
    n_lhs, n_col = 3, 5
    Ai, Aj, Ax, b = _get_rand_arrs_2d(n_lhs, 15, n_col, dtype=dtype)

    # One dense matrix per leading index, filled by scatter-add at (Ai, Aj).
    dense_batch = jnp.zeros((n_lhs, n_col, n_col), dtype=dtype).at[:, Ai, Aj].add(Ax)
    batched_dense_op = jax.vmap(OPS_DENSE[op_sparse], in_axes=(0, 0), out_axes=0)
    expected = batched_dense_op(dense_batch, b)

    actual = op_sparse(Ai, Aj, Ax, b)

    _log_and_test_equality(expected, actual)
@log_test_name
@parametrize_dtypes
@parametrize_ops
def test_2d_vmap(dtype, op_sparse):
    """Compare a vmapped sparse op with the vmapped dense reference.

    The sparse op receives the transposed ``Ax`` and ``b`` and is mapped over
    their axis 1 while ``Ai`` and ``Aj`` are held fixed; the output batch axis
    is axis 0.
    """
    n_lhs, n_col = 3, 5
    Ai, Aj, Ax, b = _get_rand_arrs_2d(n_lhs, 15, n_col, dtype=dtype)

    sparse_mapped = jax.vmap(op_sparse, in_axes=(None, None, 1, 1), out_axes=0)
    actual = sparse_mapped(Ai, Aj, Ax.T, b.T)

    dense_batch = jnp.zeros((n_lhs, n_col, n_col), dtype=dtype).at[:, Ai, Aj].add(Ax)
    dense_mapped = jax.vmap(OPS_DENSE[op_sparse], in_axes=(0, 0), out_axes=0)
    expected = dense_mapped(dense_batch, b)

    _log_and_test_equality(expected, actual)
@log_test_name
@parametrize_dtypes
@parametrize_ops
def test_3d(dtype, op_sparse):
    """Compare a sparse op with its dense reference when ``b`` is 3D.

    Inputs come from ``_get_rand_arrs_3d`` (2 matrices, 3 columns, 4
    right-hand sides); the dense reference is mapped over the leading axis.
    """
    n_lhs, n_col = 2, 3
    Ai, Aj, Ax, b = _get_rand_arrs_3d(n_lhs, 8, n_col, 4, dtype=dtype)

    dense_batch = jnp.zeros((n_lhs, n_col, n_col), dtype=dtype).at[:, Ai, Aj].add(Ax)
    batched_dense_op = jax.vmap(OPS_DENSE[op_sparse], in_axes=(0, 0), out_axes=0)
    expected = batched_dense_op(dense_batch, b)

    actual = op_sparse(Ai, Aj, Ax, b)

    _log_and_test_equality(expected, actual)
@log_test_name
@parametrize_dtypes
@parametrize_ops
def test_3d_jacfwd(dtype, op_sparse):
    """Compare forward-mode Jacobians of a sparse op and its dense reference.

    Two Jacobians are checked: with respect to ``b`` and with respect to the
    matrix values. The dense Jacobian with respect to ``A`` is indexed at
    ``(Ai, Aj)`` on its trailing axes so that it lines up with the sparse
    Jacobian with respect to ``Ax``.
    """
    n_lhs, n_col = 3, 5
    Ai, Aj, Ax, b = _get_rand_arrs_3d(n_lhs, 15, n_col, 2, dtype=dtype)
    # The holomorphic flag is switched on only for complex dtypes.
    is_complex = dtype in COMPLEX_DTYPES
    dense_op = jax.vmap(OPS_DENSE[op_sparse], in_axes=(0, 0), out_axes=0)
    # Both comparisons use the same dense matrices, so build them once.
    A = jnp.zeros((n_lhs, n_col, n_col), dtype=dtype).at[:, Ai, Aj].add(Ax)

    # Jacobian with respect to b.
    jac_sparse_b = jax.jacfwd(op_sparse, argnums=3, holomorphic=is_complex)(Ai, Aj, Ax, b)
    jac_dense_b = jax.jacfwd(dense_op, argnums=1, holomorphic=is_complex)(A, b)
    _log_and_test_equality(jac_sparse_b, jac_dense_b)

    # Jacobian with respect to the matrix values.
    jac_sparse_A = jax.jacfwd(op_sparse, argnums=2, holomorphic=is_complex)(Ai, Aj, Ax, b)
    jac_dense_A = jax.jacfwd(dense_op, argnums=0, holomorphic=is_complex)(A, b)[..., Ai, Aj]
    _log_and_test_equality(jac_sparse_A, jac_dense_A)
@log_test_name
@parametrize_dtypes
@parametrize_ops
def test_3d_jacrev(dtype, op_sparse):
    """Compare reverse-mode Jacobians of a sparse op and its dense reference.

    Two Jacobians are checked: with respect to ``b`` and with respect to the
    matrix values. The dense Jacobian with respect to ``A`` is indexed at
    ``(Ai, Aj)`` on its trailing axes so that it lines up with the sparse
    Jacobian with respect to ``Ax``.
    """
    n_lhs, n_col = 3, 5
    Ai, Aj, Ax, b = _get_rand_arrs_3d(n_lhs, 15, n_col, 2, dtype=dtype)
    # The holomorphic flag is switched on only for complex dtypes.
    is_complex = dtype in COMPLEX_DTYPES
    dense_op = jax.vmap(OPS_DENSE[op_sparse], in_axes=(0, 0), out_axes=0)
    # Both comparisons use the same dense matrices, so build them once.
    A = jnp.zeros((n_lhs, n_col, n_col), dtype=dtype).at[:, Ai, Aj].add(Ax)

    # Jacobian with respect to b.
    jac_sparse_b = jax.jacrev(op_sparse, argnums=3, holomorphic=is_complex)(Ai, Aj, Ax, b)
    jac_dense_b = jax.jacrev(dense_op, argnums=1, holomorphic=is_complex)(A, b)
    _log_and_test_equality(jac_sparse_b, jac_dense_b)

    # Jacobian with respect to the matrix values.
    jac_sparse_A = jax.jacrev(op_sparse, argnums=2, holomorphic=is_complex)(Ai, Aj, Ax, b)
    jac_dense_A = jax.jacrev(dense_op, argnums=0, holomorphic=is_complex)(A, b)[..., Ai, Aj]
    _log_and_test_equality(jac_sparse_A, jac_dense_A)
@log_test_name
@parametrize_dtypes
@parametrize_ops
def test_3d_vmap(dtype, op_sparse):
    """Map a sparse op over the last axis of a 3D ``b`` and compare with dense.

    ``Ai``, ``Aj`` and ``Ax`` are shared across the mapped axis; the mapped
    output axis is also the last one. The dense reference is mapped over the
    leading axis of the dense matrices and of ``b``.
    """
    n_lhs, n_col = 3, 5
    Ai, Aj, Ax, b = _get_rand_arrs_3d(n_lhs, 15, n_col, 2, dtype=dtype)
    _log(Ai_shape=Ai.shape, Aj_shape=Aj.shape, Ax_shape=Ax.shape, b_shape=b.shape)

    sparse_over_rhs = jax.vmap(op_sparse, in_axes=(None, None, None, -1), out_axes=-1)
    actual = sparse_over_rhs(Ai, Aj, Ax, b)

    dense_batch = jnp.zeros((n_lhs, n_col, n_col), dtype=dtype).at[:, Ai, Aj].add(Ax)
    dense_over_lhs = jax.vmap(OPS_DENSE[op_sparse], in_axes=(0, 0), out_axes=0)
    expected = dense_over_lhs(dense_batch, b)

    _log_and_test_equality(expected, actual)
@log_test_name
@parametrize_dtypes
@parametrize_ops
def test_4d(dtype, op_sparse):
    """Expect ``ValueError`` when ``Ax`` and/or ``b`` gain an extra leading axis."""
    Ai, Aj, Ax, b = _get_rand_arrs_3d(3, 15, 5, 2, dtype=dtype)

    # Each case adds one leading axis to b, to Ax, or to both.
    over_ranked_inputs = (
        (Ax, b[None]),
        (Ax[None], b),
        (Ax[None], b[None]),
    )
    for bad_Ax, bad_b in over_ranked_inputs:
        with pytest.raises(ValueError):  # noqa: PT011
            op_sparse(Ai, Aj, bad_Ax, bad_b)
@log_test_name
@parametrize_dtypes
@parametrize_ops
def test_4d_vmap(dtype, op_sparse):
    """Map a sparse op over a stacked axis of ``b`` and compare with direct calls.

    Two right-hand sides are stacked along axis 1; mapping the op over that
    axis must reproduce the results of calling it on each one separately.
    """
    Ai, Aj, Ax, b = _get_rand_arrs_3d(3, 8, 5, 2, dtype=dtype)
    b2 = np.random.RandomState(seed=42).rand(*b.shape)
    stacked_b = np.stack([b, b2], axis=1)

    mapped_op = jax.vmap(op_sparse, in_axes=(None, None, None, 1), out_axes=0)
    mapped = mapped_op(Ai, Aj, Ax, stacked_b)

    # Reference results, one unmapped call per right-hand side.
    singles = [op_sparse(Ai, Aj, Ax, rhs) for rhs in (b, b2)]

    for index, single in enumerate(singles):
        _log_and_test_equality(mapped[index], single)
@log_test_name
def test_analyze():
    """Exercise ``klujax.analyze`` both eagerly and inside ``jax.jit``.

    Eager part: the result must be an owning ``KLUHandleManager`` whose handle
    has dtype ``uint64``, and ``free_symbolic`` must mark it as freed.

    JIT part: the symbolic handle is created, used for a solve and freed
    inside a single jitted function; only the shape of the result is checked.
    """
    n_col = 5
    Ai, Aj, Ax, b = _get_rand_arrs_1d(15, n_col, dtype=np.float64)

    # 1. Eager analysis.
    symbolic = klujax.analyze(Ai, Aj, n_col)
    assert isinstance(symbolic, klujax.KLUHandleManager)
    assert symbolic._owner is True
    assert symbolic.handle.dtype == jnp.uint64

    # Free manually so the handle is released before the JIT part runs.
    klujax.free_symbolic(symbolic)
    assert symbolic._freed is True

    # 2. Analysis inside JIT.
    @jax.jit
    def jit_analyze_and_solve(Ai, Aj, Ax, b):
        # n_col is a plain Python int captured from the enclosing scope, so it
        # stays a static value here; reusing it keeps the system size defined
        # in a single place instead of repeating the literal.
        sym = klujax.analyze(Ai, Aj, n_col)  # Inside JIT, this is a Tracer
        x = klujax.solve_with_symbol(Ai, Aj, Ax, b, sym)
        # NOTE(review): `dependency=x` presumably orders the free after the
        # solve - confirm against the klujax documentation.
        klujax.free_symbolic(sym, dependency=x)
        return x

    x = jit_analyze_and_solve(Ai, Aj, Ax, b)
    assert x.shape == (n_col,)
@log_test_name
@parametrize_dtypes
def test_solve_with_symbol(dtype):
    """Check ``klujax.solve_with_symbol`` against a dense solve, eagerly and jitted."""
    n_col = 5
    Ai, Aj, Ax, b = _get_rand_arrs_1d(15, n_col, dtype=dtype)
    symbolic = klujax.analyze(Ai, Aj, n_col)

    # Dense reference built by scatter-adding Ax at (Ai, Aj).
    dense_matrix = jnp.zeros((n_col, n_col), dtype=dtype).at[Ai, Aj].add(Ax)
    expected = jsp.linalg.solve(dense_matrix, b)

    # Eager call.
    eager_result = klujax.solve_with_symbol(Ai, Aj, Ax, b, symbolic)
    _log_and_test_equality(expected, eager_result)

    # Same call through jax.jit.
    jitted_solve = jax.jit(klujax.solve_with_symbol)
    jit_result = jitted_solve(Ai, Aj, Ax, b, symbolic)
    _log_and_test_equality(expected, jit_result)

    klujax.free_symbolic(symbolic)
@log_test_name
@parametrize_dtypes
def test_solve_with_symbol_batched(dtype):
    """Check ``klujax.solve_with_symbol`` on batched inputs against dense solves."""
    n_lhs, n_col = 3, 5
    Ai, Aj, Ax, b = _get_rand_arrs_2d(n_lhs, 15, n_col, dtype=dtype)
    symbolic = klujax.analyze(Ai, Aj, n_col)

    actual = klujax.solve_with_symbol(Ai, Aj, Ax, b, symbolic)

    # Dense verification: one dense solve per leading index.
    dense_batch = jnp.zeros((n_lhs, n_col, n_col), dtype=dtype).at[:, Ai, Aj].add(Ax)
    batched_solve = jax.vmap(jsp.linalg.solve, in_axes=(0, 0), out_axes=0)
    expected = batched_solve(dense_batch, b)

    _log_and_test_equality(expected, actual)
    klujax.free_symbolic(symbolic)
@log_test_name
@parametrize_dtypes
def test_solve_with_numeric(dtype):
    """Check analyze -> factor -> ``solve_with_numeric`` against a dense solve."""
    n_col = 5
    Ai, Aj, Ax, b = _get_rand_arrs_1d(15, n_col, dtype=dtype)

    symbolic = klujax.analyze(Ai, Aj, n_col)
    numeric = klujax.factor(Ai, Aj, Ax, symbolic)
    actual = klujax.solve_with_numeric(numeric, b, symbolic)

    dense_matrix = jnp.zeros((n_col, n_col), dtype=dtype).at[Ai, Aj].add(Ax)
    expected = jsp.linalg.solve(dense_matrix, b)

    _log_and_test_equality(expected, actual)

    # Release the numeric handle first, then the symbolic one.
    klujax.free_numeric(numeric)
    klujax.free_symbolic(symbolic)
@log_test_name
@parametrize_dtypes
def test_solve_with_numeric_batched(dtype):
    """Check analyze -> factor -> ``solve_with_numeric`` on batched inputs."""
    n_lhs, n_col = 3, 5
    Ai, Aj, Ax, b = _get_rand_arrs_2d(n_lhs, 15, n_col, dtype=dtype)

    symbolic = klujax.analyze(Ai, Aj, n_col)
    numeric = klujax.factor(Ai, Aj, Ax, symbolic)
    actual = klujax.solve_with_numeric(numeric, b, symbolic)

    # Dense verification: one dense solve per leading index.
    dense_batch = jnp.zeros((n_lhs, n_col, n_col), dtype=dtype).at[:, Ai, Aj].add(Ax)
    batched_solve = jax.vmap(jsp.linalg.solve, in_axes=(0, 0), out_axes=0)
    expected = batched_solve(dense_batch, b)

    _log_and_test_equality(expected, actual)

    # Release the numeric handle first, then the symbolic one.
    klujax.free_numeric(numeric)
    klujax.free_symbolic(symbolic)
@log_test_name
@parametrize_dtypes
def test_solve_with_numeric_vmap_1d_b(dtype):
    """Map factor + ``solve_with_numeric`` over a batch of ``Ax`` with a shared 1D ``b``.

    Every batch entry uses the same ``(Ai, Aj)`` pattern and the same symbolic
    handle but different matrix values; results are compared with dense solves.
    """
    n_col = 5
    batch = 4
    Ai, Aj, Ax, b = _get_rand_arrs_1d(15, n_col, dtype=dtype)
    symbolic = klujax.analyze(Ai, Aj, n_col)

    # Different matrix values per batch entry, same sparsity; every value is
    # shifted by +10.0.
    batch_shape = (batch,) + tuple(Ax.shape)
    Ax_batch = jax.random.normal(jax.random.PRNGKey(42), batch_shape, dtype=dtype) + 10.0

    def solve_one(Ax_i):
        numeric = klujax.factor(Ai, Aj, Ax_i, symbolic)
        return klujax.solve_with_numeric(numeric, b, symbolic)

    actual = jax.vmap(solve_one)(Ax_batch)

    # Dense reference: solve each dense matrix against the shared b.
    A_batch = jnp.zeros((batch, n_col, n_col), dtype=dtype).at[:, Ai, Aj].add(Ax_batch)
    expected = jax.vmap(jsp.linalg.solve, in_axes=(0, None))(A_batch, b)

    _log_and_test_equality(expected, actual)
    klujax.free_symbolic(symbolic)
def _get_rand_arrs_1d(n_nz, n_col, *, dtype, seed=33):
    """Build random inputs for one sparse system with a 1D right-hand side.

    ``n_nz`` random ``(row, col)`` positions in ``[0, n_col)`` with normally
    distributed values are drawn, then a full diagonal of 10.0 is appended and
    the triplets are passed through ``coalesce``.

    Returns ``(Ai, Aj, Ax, b)`` where ``b`` has shape ``(n_col,)``.
    """
    value_key, row_key, col_key, rhs_key = jax.random.split(jax.random.PRNGKey(seed), 4)

    rows = jax.random.randint(row_key, (n_nz,), 0, n_col, jnp.int32)
    cols = jax.random.randint(col_key, (n_nz,), 0, n_col, jnp.int32)
    values = jax.random.normal(value_key, (n_nz,), dtype=dtype)

    # Add diagonal to ensure matrix is invertible.
    diagonal = jnp.arange(n_col, dtype=jnp.int32)
    diagonal_values = jnp.ones(n_col, dtype=dtype) * 10.0

    # NOTE(review): coalesce comes from klujax; presumably it merges duplicate
    # (row, col) entries - confirm against its documentation.
    Ai, Aj, Ax = coalesce(
        jnp.concatenate([rows, diagonal]),
        jnp.concatenate([cols, diagonal]),
        jnp.concatenate([values, diagonal_values]),
    )

    b = jax.random.normal(rhs_key, (n_col,), dtype=dtype)
    return Ai, Aj, Ax, b
def _get_rand_arrs_2d(n_lhs, n_nz, n_col, *, dtype, seed=33):
    """Build random inputs for ``n_lhs`` sparse systems sharing one pattern.

    ``n_nz`` random ``(row, col)`` positions in ``[0, n_col)`` are drawn once;
    values are drawn with shape ``(n_lhs, n_nz)``. A full diagonal of 10.0 is
    appended along axis 1 and the triplets are passed through ``coalesce``.

    Returns ``(Ai, Aj, Ax, b)`` where ``b`` has shape ``(n_lhs, n_col)``.
    """
    value_key, row_key, col_key, rhs_key = jax.random.split(jax.random.PRNGKey(seed), 4)

    rows = jax.random.randint(row_key, (n_nz,), 0, n_col, jnp.int32)
    cols = jax.random.randint(col_key, (n_nz,), 0, n_col, jnp.int32)
    values = jax.random.normal(value_key, (n_lhs, n_nz), dtype=dtype)

    # Add diagonal to ensure matrix is invertible.
    diagonal = jnp.arange(n_col, dtype=jnp.int32)
    diagonal_values = jnp.ones((n_lhs, n_col), dtype=dtype) * 10.0

    # NOTE(review): coalesce comes from klujax; presumably it merges duplicate
    # (row, col) entries - confirm against its documentation.
    Ai, Aj, Ax = coalesce(
        jnp.concatenate([rows, diagonal]),
        jnp.concatenate([cols, diagonal]),
        jnp.concatenate([values, diagonal_values], axis=1),
    )

    b = jax.random.normal(rhs_key, (n_lhs, n_col), dtype=dtype)
    return Ai, Aj, Ax, b
def _get_rand_arrs_3d(n_lhs, n_nz, n_col, n_rhs, *, dtype, seed=33):
    """Build random inputs for ``n_lhs`` sparse systems with ``n_rhs`` right-hand sides.

    ``n_nz`` random ``(row, col)`` positions in ``[0, n_col)`` are drawn once;
    values are drawn with shape ``(n_lhs, n_nz)``. A full diagonal of 10.0 is
    appended along axis 1 and the triplets are passed through ``coalesce``.

    Returns ``(Ai, Aj, Ax, b)`` where ``b`` has shape ``(n_lhs, n_col, n_rhs)``.
    """
    value_key, row_key, col_key, rhs_key = jax.random.split(jax.random.PRNGKey(seed), 4)

    rows = jax.random.randint(row_key, (n_nz,), 0, n_col, jnp.int32)
    cols = jax.random.randint(col_key, (n_nz,), 0, n_col, jnp.int32)
    values = jax.random.normal(value_key, (n_lhs, n_nz), dtype=dtype)

    # Add diagonal to ensure matrix is invertible.
    diagonal = jnp.arange(n_col, dtype=jnp.int32)
    diagonal_values = jnp.ones((n_lhs, n_col), dtype=dtype) * 10.0

    # NOTE(review): coalesce comes from klujax; presumably it merges duplicate
    # (row, col) entries - confirm against its documentation.
    Ai, Aj, Ax = coalesce(
        jnp.concatenate([rows, diagonal]),
        jnp.concatenate([cols, diagonal]),
        jnp.concatenate([values, diagonal_values], axis=1),
    )

    b = jax.random.normal(rhs_key, (n_lhs, n_col, n_rhs), dtype=dtype)
    return Ai, Aj, Ax, b
def _log_and_test_equality(x, x_sp):
    """Print both arrays, their rounded difference and a verdict, then assert equality.

    The comparison is ``np.testing.assert_array_almost_equal``, so the call
    raises ``AssertionError`` when the arrays differ.
    """

    def show(label, value):
        # Each section is a blank line, the label with "=", then the value.
        print(f"\n{label}=\n{value}")

    show("x", x)
    show("x_sp", x_sp)
    show("diff", np.round(x_sp - x, 9))
    show("is_equal", _is_almost_equal(x, x_sp))
    np.testing.assert_array_almost_equal(x, x_sp)
def _log(**kwargs):
    """Print every keyword argument as ``name=value``, one per line."""
    for label in kwargs:
        print(f"{label}={kwargs[label]}")
def _is_almost_equal(arr1, arr2):
    """Return True when ``np.testing.assert_array_almost_equal`` accepts the pair.

    The ``AssertionError`` raised on a mismatch is converted to ``False``.
    """
    try:
        np.testing.assert_array_almost_equal(arr1, arr2)
        return True
    except AssertionError:
        return False
def _make_ax2(Ai, Aj, Ax, *, dtype, seed=42):
    """Draw fresh values shaped like ``Ax`` and add 10.0 on the diagonal entries.

    Values are normally distributed with ``Ax.shape`` and ``dtype``. Entries
    whose row and column index coincide (``Ai[k] == Aj[k]``) get 10.0 added so
    the diagonal stays large. Works for ``Ax`` of shape ``(n_nz,)`` and for
    the batched shape ``(n_lhs, n_nz)``.
    """
    fresh = jax.random.normal(jax.random.PRNGKey(seed), Ax.shape, dtype=dtype)
    # 10.0 at diagonal positions and 0.0 elsewhere, in the dtype of the values.
    # The mask indexes the last axis, so broadcasting also covers the batched case.
    diagonal_boost = jnp.where(Ai == Aj, 10.0, 0.0).astype(fresh.dtype)
    return fresh + diagonal_boost
@log_test_name
@parametrize_dtypes
def test_refactor(dtype):
    """Solve with the handle returned by ``klujax.refactor`` and compare with dense.

    The system is factored with ``Ax`` and refactored with new values ``Ax2``
    on the same ``(Ai, Aj)`` pattern; the solve for ``b2`` is compared with a
    dense solve of the matrix built from ``Ax2``.
    """
    n_col = 5
    Ai, Aj, Ax, b = _get_rand_arrs_1d(15, n_col, dtype=dtype)
    Ax2 = _make_ax2(Ai, Aj, Ax, dtype=dtype)
    b2 = jax.random.normal(jax.random.PRNGKey(43), b.shape, dtype=dtype)

    sym = klujax.analyze(Ai, Aj, n_col)
    num = klujax.factor(Ai, Aj, Ax, sym)
    num2 = klujax.refactor(Ai, Aj, Ax2, num, sym)
    assert isinstance(num2, klujax.KLUHandleManager)

    actual = klujax.solve_with_numeric(num2, b2, sym)

    dense_matrix = jnp.zeros((n_col, n_col), dtype=dtype).at[Ai, Aj].add(Ax2)
    expected = jsp.linalg.solve(dense_matrix, b2)

    _log_and_test_equality(expected, actual)

    # The original numeric handle (not num2) is the one released here.
    klujax.free_numeric(num)
    klujax.free_symbolic(sym)
@log_test_name
@parametrize_dtypes
def test_refactor_batched(dtype):
    """Batched variant of the refactor check: factor, refactor, solve, compare.

    Inputs carry a leading axis of size ``n_lhs``; the dense reference is a
    dense solve mapped over that axis on the matrices built from ``Ax2``.
    """
    n_lhs, n_col = 3, 5
    Ai, Aj, Ax, b = _get_rand_arrs_2d(n_lhs, 15, n_col, dtype=dtype)
    Ax2 = _make_ax2(Ai, Aj, Ax, dtype=dtype)
    b2 = jax.random.normal(jax.random.PRNGKey(43), b.shape, dtype=dtype)

    sym = klujax.analyze(Ai, Aj, n_col)
    num = klujax.factor(Ai, Aj, Ax, sym)
    num2 = klujax.refactor(Ai, Aj, Ax2, num, sym)
    assert isinstance(num2, klujax.KLUHandleManager)

    actual = klujax.solve_with_numeric(num2, b2, sym)

    dense_batch = jnp.zeros((n_lhs, n_col, n_col), dtype=dtype).at[:, Ai, Aj].add(Ax2)
    batched_solve = jax.vmap(jsp.linalg.solve, in_axes=(0, 0), out_axes=0)
    expected = batched_solve(dense_batch, b2)

    _log_and_test_equality(expected, actual)

    # The original numeric handle (not num2) is the one released here.
    klujax.free_numeric(num)
    klujax.free_symbolic(sym)
@log_test_name
@parametrize_dtypes
def test_refactor_vmap(dtype):
    """Map ``solve_with_symbol`` over the right-hand-side axis after a refactor.

    A factor + refactor is run on the shared symbolic handle first; the mapped
    solve then reuses that handle with the new values ``Ax2``. Agreement with
    the dense reference shows the symbolic handle is still usable after
    ``refactor`` and that ``vmap`` composes with it.
    """
    n_lhs, n_col = 3, 5
    Ai, Aj, Ax, b = _get_rand_arrs_3d(n_lhs, 15, n_col, 4, dtype=dtype)
    Ax2 = _make_ax2(Ai, Aj, Ax, dtype=dtype)

    sym = klujax.analyze(Ai, Aj, n_col)
    num = klujax.factor(Ai, Aj, Ax, sym)
    num2 = klujax.refactor(Ai, Aj, Ax2, num, sym)

    # Map the sparse solve over the last axis of b; everything else is shared.
    solve_over_rhs = jax.vmap(
        klujax.solve_with_symbol,
        in_axes=(None, None, None, -1, None),
        out_axes=-1,
    )
    actual = solve_over_rhs(Ai, Aj, Ax2, b, sym)

    # Dense reference: inner map over the leading axis, outer map over the
    # last axis of b.
    dense_batch = jnp.zeros((n_lhs, n_col, n_col), dtype=dtype).at[:, Ai, Aj].add(Ax2)
    solve_over_lhs = jax.vmap(jsp.linalg.solve, in_axes=(0, 0), out_axes=0)
    expected = jax.vmap(solve_over_lhs, in_axes=(None, -1), out_axes=-1)(dense_batch, b)

    _log_and_test_equality(expected, actual)

    # Here the handle returned by refactor is the one released.
    klujax.free_numeric(num2)
    klujax.free_symbolic(sym)
@log_test_name
@parametrize_dtypes
def test_refactor_pmap(dtype):
    """Run ``solve_with_numeric`` under ``jax.pmap`` after a refactor.

    One right-hand side is generated per available device; the dense
    reference solves the matrix built from ``Ax2`` against each of them.
    """
    n_col = 5
    n_dev = jax.device_count()
    Ai, Aj, Ax, _ = _get_rand_arrs_1d(15, n_col, dtype=dtype)
    Ax2 = _make_ax2(Ai, Aj, Ax, dtype=dtype)
    B = jax.random.normal(jax.random.PRNGKey(77), (n_dev, n_col), dtype=dtype)

    sym = klujax.analyze(Ai, Aj, n_col)
    num = klujax.factor(Ai, Aj, Ax, sym)
    num2 = klujax.refactor(Ai, Aj, Ax2, num, sym)

    def solve_on_device(b):
        return klujax.solve_with_numeric(num2, b, sym)

    actual = jax.pmap(solve_on_device)(B)

    dense_matrix = jnp.zeros((n_col, n_col), dtype=dtype).at[Ai, Aj].add(Ax2)
    expected = jax.vmap(jsp.linalg.solve, in_axes=(None, 0))(dense_matrix, B)

    _log_and_test_equality(expected, actual)

    # Here the handle returned by refactor is the one released.
    klujax.free_numeric(num2)
    klujax.free_symbolic(sym)
@log_test_name
@parametrize_dtypes
def test_refactor_grad(dtype):
    """Differentiate ``solve_with_symbol`` after ``refactor`` used the same symbolic handle.

    First the refactored numeric solve and the symbolic solve must agree on
    ``Ax2``; then the forward-mode Jacobian with respect to the matrix values
    must match the one of an equivalent dense solve.
    """
    n_col = 5
    Ai, Aj, Ax, _ = _get_rand_arrs_1d(15, n_col, dtype=dtype)
    Ax2 = _make_ax2(Ai, Aj, Ax, dtype=dtype)
    b2 = jax.random.normal(jax.random.PRNGKey(43), (n_col,), dtype=dtype)
    # The holomorphic flag is switched on only for complex dtypes.
    is_complex = dtype in COMPLEX_DTYPES

    sym = klujax.analyze(Ai, Aj, n_col)
    num = klujax.factor(Ai, Aj, Ax, sym)
    num2 = klujax.refactor(Ai, Aj, Ax2, num, sym)

    # The value from the refactored handle and the symbolic solve must agree.
    via_numeric = klujax.solve_with_numeric(num2, b2, sym)
    via_symbolic = klujax.solve_with_symbol(Ai, Aj, Ax2, b2, sym)
    _log_and_test_equality(via_numeric, via_symbolic)

    def sparse_solve(ax):
        return klujax.solve_with_symbol(Ai, Aj, ax, b2, sym)

    def dense_solve(ax):
        dense_matrix = jnp.zeros((n_col, n_col), dtype=dtype).at[Ai, Aj].add(ax)
        return jnp.linalg.solve(dense_matrix, b2)

    # Jacobian with respect to Ax: checks that sym still works after refactor.
    jac_sparse = jax.jacfwd(sparse_solve, holomorphic=is_complex)(Ax2)
    jac_dense = jax.jacfwd(dense_solve, holomorphic=is_complex)(Ax2)
    _log_and_test_equality(jac_sparse, jac_dense)

    klujax.free_numeric(num2)
    klujax.free_symbolic(sym)
def test_solve_with_symbol_jvp():
    """Check the JVP of ``klujax.solve_with_symbol`` on a 2x2 diagonal system.

    With ``A = diag(2, 4)`` and ``b = (10, 20)`` the solution is
    ``x = (5, 5)``. Perturbing only the matrix values by ``dAx`` gives
    ``dx = -dAx * x / Ax``, i.e. ``(-0.25, -0.125)`` for ``dAx = (0.1, 0.1)``.
    Both the shapes and these values are asserted, so an incorrect JVP rule
    is caught as well as a missing one.
    """
    Ai = jnp.array([0, 1], dtype=jnp.int32)
    Aj = jnp.array([0, 1], dtype=jnp.int32)
    Ax = jnp.array([2.0, 4.0], dtype=jnp.float64)  # diagonal values
    b = jnp.array([10.0, 20.0], dtype=jnp.float64)

    # Pre-compute the symbolic analysis shared by every solve below.
    n_col = 2
    symbolic = klujax.analyze(Ai, Aj, n_col)

    def solve_step(Ax_vals, b_vec):
        return klujax.solve_with_symbol(Ai, Aj, Ax_vals, b_vec, symbolic)

    # Tangents (perturbations): only the matrix values are perturbed.
    tangent_Ax = jnp.array([0.1, 0.1], dtype=jnp.float64)
    tangent_b = jnp.array([0.0, 0.0], dtype=jnp.float64)

    # This triggers the JVP rule lookup.
    primals, tangents = jax.jvp(solve_step, (Ax, b), (tangent_Ax, tangent_b))

    assert primals.shape == (2,)
    assert tangents.shape == (2,)
    # Comparing values also materializes the results before the handle is freed.
    np.testing.assert_allclose(primals, [5.0, 5.0], rtol=1e-6)
    np.testing.assert_allclose(tangents, [-0.25, -0.125], rtol=1e-6)

    # Release the handle explicitly, as the other tests in this file do.
    klujax.free_symbolic(symbolic)
@log_test_name
@parametrize_dtypes
def test_tsolve_with_symbol(dtype):
    """Check ``klujax.tsolve_with_symbol`` against a dense solve with ``A.T``, eagerly and jitted."""
    n_col = 5
    Ai, Aj, Ax, b = _get_rand_arrs_1d(15, n_col, dtype=dtype)
    symbolic = klujax.analyze(Ai, Aj, n_col)

    # The reference for tsolve is a dense solve with the transposed matrix.
    dense_matrix = jnp.zeros((n_col, n_col), dtype=dtype).at[Ai, Aj].add(Ax)
    expected = jsp.linalg.solve(dense_matrix.T, b)

    # Eager call.
    eager_result = klujax.tsolve_with_symbol(Ai, Aj, Ax, b, symbolic)
    _log_and_test_equality(expected, eager_result)

    # Same call through jax.jit.
    jitted_tsolve = jax.jit(klujax.tsolve_with_symbol)
    jit_result = jitted_tsolve(Ai, Aj, Ax, b, symbolic)
    _log_and_test_equality(expected, jit_result)

    klujax.free_symbolic(symbolic)
@log_test_name
@parametrize_dtypes
def test_tsolve_with_symbol_batched(dtype):
    """Check ``klujax.tsolve_with_symbol`` on batched inputs against transposed dense solves."""
    n_lhs, n_col = 3, 5
    Ai, Aj, Ax, b = _get_rand_arrs_2d(n_lhs, 15, n_col, dtype=dtype)
    symbolic = klujax.analyze(Ai, Aj, n_col)

    actual = klujax.tsolve_with_symbol(Ai, Aj, Ax, b, symbolic)

    # Dense verification: transpose each matrix (swap the last two axes) and
    # solve per leading index.
    dense_batch = jnp.zeros((n_lhs, n_col, n_col), dtype=dtype).at[:, Ai, Aj].add(Ax)
    transposed_batch = jnp.swapaxes(dense_batch, -1, -2)
    batched_solve = jax.vmap(jsp.linalg.solve, in_axes=(0, 0), out_axes=0)
    expected = batched_solve(transposed_batch, b)

    _log_and_test_equality(expected, actual)
    klujax.free_symbolic(symbolic)
@log_test_name
@parametrize_dtypes
def test_tsolve_with_numeric(dtype):
    """Check analyze -> factor -> ``tsolve_with_numeric`` against a dense solve with ``A.T``."""
    n_col = 5
    Ai, Aj, Ax, b = _get_rand_arrs_1d(15, n_col, dtype=dtype)

    symbolic = klujax.analyze(Ai, Aj, n_col)
    numeric = klujax.factor(Ai, Aj, Ax, symbolic)
    actual = klujax.tsolve_with_numeric(numeric, b, symbolic)

    # The reference for tsolve is a dense solve with the transposed matrix.
    dense_matrix = jnp.zeros((n_col, n_col), dtype=dtype).at[Ai, Aj].add(Ax)
    expected = jsp.linalg.solve(dense_matrix.T, b)

    _log_and_test_equality(expected, actual)

    # Release the numeric handle first, then the symbolic one.
    klujax.free_numeric(numeric)
    klujax.free_symbolic(symbolic)
@log_test_name
@parametrize_dtypes
def test_tsolve_with_numeric_batched(dtype):
    """Check analyze -> factor -> ``tsolve_with_numeric`` on batched inputs."""
    n_lhs, n_col = 3, 5
    Ai, Aj, Ax, b = _get_rand_arrs_2d(n_lhs, 15, n_col, dtype=dtype)

    symbolic = klujax.analyze(Ai, Aj, n_col)
    numeric = klujax.factor(Ai, Aj, Ax, symbolic)
    actual = klujax.tsolve_with_numeric(numeric, b, symbolic)

    # Dense verification: transpose each matrix (swap the last two axes) and
    # solve per leading index.
    dense_batch = jnp.zeros((n_lhs, n_col, n_col), dtype=dtype).at[:, Ai, Aj].add(Ax)
    transposed_batch = jnp.swapaxes(dense_batch, -1, -2)
    batched_solve = jax.vmap(jsp.linalg.solve, in_axes=(0, 0), out_axes=0)
    expected = batched_solve(transposed_batch, b)

    _log_and_test_equality(expected, actual)

    # Release the numeric handle first, then the symbolic one.
    klujax.free_numeric(numeric)
    klujax.free_symbolic(symbolic)
@log_test_name
@parametrize_dtypes
def test_refactor_and_solve(dtype):
    """Check ``klujax.refactor_and_solve`` against a dense solve with the new values.

    The call returns the solution together with a numeric handle; that handle
    must be a ``KLUHandleManager`` with ``_owner`` set to False.
    """
    n_col = 5
    Ai, Aj, Ax, b = _get_rand_arrs_1d(15, n_col, dtype=dtype)
    Ax2 = _make_ax2(Ai, Aj, Ax, dtype=dtype)
    b2 = jax.random.normal(jax.random.PRNGKey(43), b.shape, dtype=dtype)

    sym = klujax.analyze(Ai, Aj, n_col)
    num = klujax.factor(Ai, Aj, Ax, sym)

    # Argument order under test: Ai, Aj, Ax, b, numeric, symbolic.
    actual, num2 = klujax.refactor_and_solve(Ai, Aj, Ax2, b2, num, sym)
    assert isinstance(num2, klujax.KLUHandleManager)
    assert num2._owner is False

    dense_matrix = jnp.zeros((n_col, n_col), dtype=dtype).at[Ai, Aj].add(Ax2)
    expected = jsp.linalg.solve(dense_matrix, b2)
    _log_and_test_equality(expected, actual)

    # num2 is not an owner (asserted above), so the original handle is the
    # one that has to be freed manually.
    klujax.free_numeric(num)
    klujax.free_symbolic(sym)
@log_test_name
@parametrize_dtypes
def test_refactor_and_solve_batched(dtype):
    """Batched variant of the ``klujax.refactor_and_solve`` check.

    The returned numeric handle must be a ``KLUHandleManager`` with ``_owner``
    set to False; the solution is compared with dense solves mapped over the
    leading axis.
    """
    n_lhs, n_col = 3, 5
    Ai, Aj, Ax, b = _get_rand_arrs_2d(n_lhs, 15, n_col, dtype=dtype)
    Ax2 = _make_ax2(Ai, Aj, Ax, dtype=dtype)
    b2 = jax.random.normal(jax.random.PRNGKey(43), b.shape, dtype=dtype)

    sym = klujax.analyze(Ai, Aj, n_col)
    num = klujax.factor(Ai, Aj, Ax, sym)

    # Argument order under test: Ai, Aj, Ax, b, numeric, symbolic.
    actual, num2 = klujax.refactor_and_solve(Ai, Aj, Ax2, b2, num, sym)
    assert isinstance(num2, klujax.KLUHandleManager)
    assert num2._owner is False

    dense_batch = jnp.zeros((n_lhs, n_col, n_col), dtype=dtype).at[:, Ai, Aj].add(Ax2)
    batched_solve = jax.vmap(jsp.linalg.solve, in_axes=(0, 0), out_axes=0)
    expected = batched_solve(dense_batch, b2)
    _log_and_test_equality(expected, actual)

    # The original numeric handle is the one released here.
    klujax.free_numeric(num)
    klujax.free_symbolic(sym)
# KLUHandleManager testing
def use_handle(manager, x):
    """Stand in for a function that needs a concrete handle (e.g. a C pointer).

    Asserts that ``manager.handle`` is not a JAX ``Tracer`` and returns
    ``x * 2.0``.
    """
    handle_is_traced = isinstance(manager.handle, jax.core.Tracer)
    assert not handle_is_traced, (
        "Handle was traced! Got a Tracer instead of a concrete value."
    )
    return x * 2.0
def test_registration_traces_handle():
    """Pass a ``KLUHandleManager`` through ``jax.jit`` and use its handle.

    ``use_handle`` asserts that the handle it sees is not a Tracer, so this
    passes only if the handle stays concrete while the function is jitted.
    """
    handle = jnp.array(0xDEADBEEF, dtype=jnp.int64)
    manager = klujax.KLUHandleManager(handle, free_callable=lambda x: None)

    jitted = jax.jit(use_handle)
    doubled = jitted(manager, jnp.array(1.0))

    assert jnp.allclose(doubled, 2.0)
def test_registration_handle_concrete_under_grad():
    """Use a closed-over ``KLUHandleManager`` inside ``jax.grad``.

    ``use_handle`` asserts that the handle is not a Tracer; the gradient value
    itself is not checked.
    """
    handle = jnp.array(0xDEADBEEF, dtype=jnp.int64)
    manager = klujax.KLUHandleManager(handle, free_callable=lambda x: None)

    def doubled(x):
        return use_handle(manager, x)

    jax.grad(doubled)(jnp.array(1.0))
def test_handle_survives_pytree_roundtrip():
    """Flatten and unflatten a ``KLUHandleManager`` and check the handle value.

    The manager must flatten to zero leaves, and the rebuilt manager must
    carry the same integer handle value.
    """
    original_handle = jnp.array(0xDEADBEEF, dtype=jnp.int64)
    manager = klujax.KLUHandleManager(original_handle, free_callable=lambda x: None)

    flat_leaves, structure = jax.tree_util.tree_flatten(manager)
    rebuilt = structure.unflatten(flat_leaves)

    assert flat_leaves == []
    assert int(rebuilt.handle) == int(original_handle)
def test_registration_handle_concrete_under_vmap():
    """Use a closed-over ``KLUHandleManager`` inside ``jax.vmap``.

    ``use_handle`` asserts that the handle is not a Tracer; the mapped result
    over four ones must be all twos.
    """
    handle = jnp.array(0xDEADBEEF, dtype=jnp.int64)
    manager = klujax.KLUHandleManager(handle, free_callable=lambda x: None)

    def doubled(x):
        return use_handle(manager, x)

    mapped_result = jax.vmap(doubled)(jnp.ones((4,)))

    assert jnp.allclose(mapped_result, 2.0)