Skip to content

Commit 985ff4c

Browse files
authored
Merge pull request #10 from tmcclintock/feat/type-checking
Feat/type checking
2 parents 8a33bc7 + 95f1d4d commit 985ff4c

File tree

4 files changed

+17
-11
lines changed

4 files changed

+17
-11
lines changed

.github/workflows/ci.yaml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,14 @@ jobs:
2424
- name: Install the project and dependencies
2525
run: uv sync --all-extras
2626

27-
- uses: astral-sh/ruff-action@v3
27+
- name: Lint
28+
uses: astral-sh/ruff-action@v3
29+
30+
- name: Type checkg
31+
run: |
32+
uv tool install ty
33+
ty check src
34+
ty check tests
2835
2936
- name: Run tests with coverage
3037
run: uv run coverage run -m pytest

src/ConditionalGMM/UnivariateGMM.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Helpful functions that can only be computed (easily) for univarate GMMs."""
22

33
import numpy as np
4-
import scipy as sp
4+
from scipy.optimize import newton
5+
from scipy.stats import norm
56

67

78
class UniGMM:
@@ -49,7 +50,7 @@ def pdf(self, x):
4950
assert np.ndim(x) < 2
5051
# TODO vectorize
5152
pdfs = np.array(
52-
[sp.stats.norm.pdf(x, mi, vi) for mi, vi in zip(self.means, self.variances)]
53+
[norm.pdf(x, mi, vi) for mi, vi in zip(self.means, self.variances)]
5354
)
5455
return np.dot(self.weights, pdfs)
5556

@@ -76,7 +77,7 @@ def cdf(self, x):
7677
7778
"""
7879
cdfs = np.array(
79-
[sp.stats.norm.cdf(x, mi, vi) for mi, vi in zip(self.means, self.variances)]
80+
[norm.cdf(x, mi, vi) for mi, vi in zip(self.means, self.variances)]
8081
)
8182
return np.dot(self.weights, cdfs)
8283

@@ -103,9 +104,7 @@ def ppf(self, q):
103104
(float or array-like) quantile corresponding to `q`
104105
105106
"""
106-
return sp.optimize.newton(
107-
func=lambda x: self.cdf(x) - q, x0=self.mean(), fprime=self.pdf
108-
)
107+
return newton(func=lambda x: self.cdf(x) - q, x0=self.mean(), fprime=self.pdf)
109108

110109
def mean(self):
111110
"""Mean of the RV for the GMM.

src/ConditionalGMM/condGMM.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Conditional Gaussian mixture model."""
22

33
import numpy as np
4-
import scipy as sp
4+
from scipy.stats import multivariate_normal
55

66
from ConditionalGMM.MNorm import CondMNorm
77

@@ -103,7 +103,7 @@ def unconditional_pdf_x2(self, x2=None, component_probs=False):
103103

104104
probs = w * np.array(
105105
[
106-
sp.stats.multivariate_normal.pdf(x2, mean=mus[i], cov=covs[i])
106+
multivariate_normal.pdf(x2, mean=mus[i], cov=covs[i])
107107
for i in range(len(w))
108108
]
109109
)

tests/test_UGMM.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44
import numpy.testing as npt
5-
import scipy as sp
5+
from scipy.stats import norm
66

77
from ConditionalGMM import UnivariateGMM
88

@@ -25,7 +25,7 @@ def test_uggm_pdf():
2525
pdf = ugmm.pdf(x)
2626
truepdf = np.dot(
2727
np.array(weights),
28-
np.array([sp.stats.norm.pdf(x, mi, vi) for mi, vi in zip(means, vars)]),
28+
np.array([norm.pdf(x, mi, vi) for mi, vi in zip(means, vars)]),
2929
)
3030
npt.assert_equal(pdf, truepdf)
3131

0 commit comments

Comments
 (0)