Skip to content

Commit 3c0a99a

Browse files
andrewkernbenjeffery
authored andcommitted
refactor: add k_values parameter to TestTraitLinearModel.example_covariates
- Modified example_covariates() to accept optional k_values parameter - Default k_values=[2] maintains current optimization performance - Updated verify() method to explicitly pass k_values=[2] - Removed unused trait_covariate_cache fixture This change makes the test method more flexible while preserving the performance optimization from commit 4653b6e that reduced test execution time from 42s to 35s.
1 parent 08e16e3 commit 3c0a99a

File tree

1 file changed

+19
-59
lines changed

1 file changed

+19
-59
lines changed

python/tests/test_tree_stats.py

Lines changed: 19 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -803,51 +803,6 @@ def wf_mut_sim_fixture():
803803
return simulations
804804

805805

806-
# Fixture for TraitLinearModel matrix operations
807-
@pytest.fixture(scope="session")
808-
def trait_covariate_cache():
809-
"""Cache expensive matrix operations for TraitLinearModel tests."""
810-
cache = {}
811-
np.random.seed(999) # Same seed as example_covariates
812-
813-
# Pre-compute for common sample sizes
814-
for N in [6, 10, 12]: # Most common test sizes
815-
k = min(2, N) # We now only use k=2
816-
817-
# Uniform covariates transform
818-
Z = np.ones((N, k))
819-
Z[1, :] = np.arange(k, 2 * k)
820-
tZ = np.column_stack([Z, np.ones((N, 1))])
821-
if np.linalg.matrix_rank(tZ) == tZ.shape[1]:
822-
Z_full = tZ
823-
else:
824-
Z_full = Z
825-
826-
if np.linalg.matrix_rank(Z_full) == Z_full.shape[1]:
827-
K = np.linalg.cholesky(np.matmul(Z_full.T, Z_full)).T
828-
Z_transformed = np.matmul(Z_full, np.linalg.inv(K))
829-
cache[(N, "uniform")] = Z_transformed
830-
831-
# Normal covariates transform (only for N >= 6)
832-
if N >= 6:
833-
Z = np.ones((N, k))
834-
for j in range(k):
835-
Z[:, j] = np.random.normal(0, 1, N)
836-
837-
tZ = np.column_stack([Z, np.ones((N, 1))])
838-
if np.linalg.matrix_rank(tZ) == tZ.shape[1]:
839-
Z_full = tZ
840-
else:
841-
Z_full = Z
842-
843-
if np.linalg.matrix_rank(Z_full) == Z_full.shape[1]:
844-
K = np.linalg.cholesky(np.matmul(Z_full.T, Z_full)).T
845-
Z_transformed = np.matmul(Z_full, np.linalg.inv(K))
846-
cache[(N, "normal")] = Z_transformed
847-
848-
return cache
849-
850-
851806
class MutatedTopologyExamplesMixin:
852807
"""
853808
Defines a set of test cases on different example tree sequence topologies.
@@ -5588,23 +5543,28 @@ def get_example_ts(self, ts_10_mut_recomb_fixture):
55885543
assert ts.num_mutations > 0
55895544
return ts
55905545

5591-
def example_covariates(self, ts):
5546+
def example_covariates(self, ts, k_values=None):
5547+
if k_values is None:
5548+
k_values = [2] # Default to [2] to maintain current optimization
5549+
55925550
np.random.seed(999)
55935551
N = ts.num_samples
5594-
# Reduced combinations for performance: only k=2 instead of [1, 2, 5]
5595-
k = min(2, ts.num_samples)
5596-
5597-
# Uniform covariates
5598-
Z = np.ones((N, k))
5599-
Z[1, :] = np.arange(k, 2 * k)
5600-
yield Z
5601-
5602-
# Include one normal case for test coverage
5603-
if N >= 6: # Only for larger samples to reduce computations
5604-
for j in range(k):
5605-
Z[:, j] = np.random.normal(0, 1, N)
5552+
5553+
for k in k_values:
5554+
k = min(k, ts.num_samples)
5555+
5556+
# Uniform covariates
5557+
Z = np.ones((N, k))
5558+
Z[1, :] = np.arange(k, 2 * k)
56065559
yield Z
56075560

5561+
# Include one normal case for test coverage
5562+
if N >= 6: # Only for larger samples to reduce computations
5563+
Z_normal = np.ones((N, k))
5564+
for j in range(k):
5565+
Z_normal[:, j] = np.random.normal(0, 1, N)
5566+
yield Z_normal
5567+
56085568
def transform_weights(self, W, Z):
56095569
n = W.shape[0]
56105570
return np.column_stack([W, Z, np.ones((n, 1))])
@@ -5621,7 +5581,7 @@ def transform_covariates(self, Z):
56215581
def verify(self, ts):
56225582
for W, Z, windows in subset_combos(
56235583
self.example_weights(ts),
5624-
self.example_covariates(ts),
5584+
self.example_covariates(ts, k_values=[2]),
56255585
example_windows(ts),
56265586
p=0.02, # Reduced from 0.04 for performance
56275587
):

0 commit comments

Comments
 (0)