Skip to content

Commit ce23675

Browse files
committed
update ruff options + the fixes to comply
1 parent 7fece19 commit ce23675

22 files changed

+3731
-3730
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ repos:
3636
- id: ruff
3737
types_or: [ python, pyi, jupyter ]
3838
args: [ --fix ]
39+
# Exclude docs/ to avoid applying strict linting rules to example notebooks
40+
# Remove this exclusion if you want to enforce strict rules on documentation
41+
exclude: ^docs/
3942
# Run the formatter
4043
- id: ruff-format
4144
types_or: [ python, pyi, jupyter ]

causalpy/data/simulate_data.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def impact(x: np.ndarray) -> np.ndarray:
308308

309309
def generate_ancova_data(
310310
N: int = 200,
311-
pre_treatment_means: np.ndarray = np.array([10, 12]),
311+
pre_treatment_means: np.ndarray | None = None,
312312
treatment_effect: int = 2,
313313
sigma: int = 1,
314314
) -> pd.DataFrame:
@@ -324,6 +324,8 @@ def generate_ancova_data(
324324
... )
325325
>>> df.to_csv(pathlib.Path.cwd() / "ancova_data.csv", index=False) # doctest: +SKIP
326326
"""
327+
if pre_treatment_means is None:
328+
pre_treatment_means = np.array([10, 12])
327329
group = np.random.choice(2, size=N)
328330
pre = np.random.normal(loc=pre_treatment_means[group])
329331
post = pre + treatment_effect * group + np.random.normal(size=N) * sigma

causalpy/experiments/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"""
1717

1818
from abc import abstractmethod
19-
from typing import Any, Literal, Union
19+
from typing import Any, Literal
2020

2121
import arviz as az
2222
import matplotlib.pyplot as plt
@@ -54,7 +54,7 @@ class BaseExperiment:
5454
supports_bayes: bool
5555
supports_ols: bool
5656

57-
def __init__(self, model: Union[PyMCModel, RegressorMixin] | None = None) -> None:
57+
def __init__(self, model: PyMCModel | RegressorMixin | None = None) -> None:
5858
# Ensure we've made any provided Scikit Learn model (as identified as being type
5959
# RegressorMixin) compatible with CausalPy by appending our custom methods.
6060
if isinstance(model, RegressorMixin):
@@ -141,7 +141,7 @@ def get_plot_data_ols(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
141141

142142
def effect_summary(
143143
self,
144-
window: Union[Literal["post"], tuple, slice] = "post",
144+
window: Literal["post"] | tuple | slice = "post",
145145
direction: Literal["increase", "decrease", "two-sided"] = "increase",
146146
alpha: float = 0.05,
147147
cumulative: bool = True,

causalpy/experiments/diff_in_diff.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
Difference in differences
1616
"""
1717

18-
from typing import Union
19-
2018
import arviz as az
2119
import numpy as np
2220
import pandas as pd
@@ -98,7 +96,7 @@ def __init__(
9896
time_variable_name: str,
9997
group_variable_name: str,
10098
post_treatment_variable_name: str = "post_treatment",
101-
model: Union[PyMCModel, RegressorMixin] | None = None,
99+
model: PyMCModel | RegressorMixin | None = None,
102100
**kwargs: dict,
103101
) -> None:
104102
super().__init__(model=model)
@@ -234,7 +232,7 @@ def __init__(
234232
elif isinstance(self.model, RegressorMixin):
235233
# This is the coefficient on the interaction term
236234
# Store the coefficient into dictionary {intercept:value}
237-
coef_map = dict(zip(self.labels, self.model.get_coeffs()))
235+
coef_map = dict(zip(self.labels, self.model.get_coeffs(), strict=False))
238236
# Create and find the interaction term based on the values user provided
239237
interaction_term = (
240238
f"{self.group_variable_name}:{self.post_treatment_variable_name}"

causalpy/experiments/instrumental_variable.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,9 @@ def input_validation(self) -> None:
161161
"""Warning. The treatment variable is not Binary.
162162
This is not necessarily a problem but it violates
163163
the assumption of a simple IV experiment.
164-
The coefficients should be interpreted appropriately."""
164+
The coefficients should be interpreted appropriately.""",
165+
UserWarning,
166+
stacklevel=2,
165167
)
166168

167169
def get_2SLS_fit(self) -> None:
@@ -195,7 +197,9 @@ def get_naive_OLS_fit(self) -> None:
195197
ols_reg = sk_lin_reg().fit(self.X, self.y)
196198
beta_params = list(ols_reg.coef_[0][1:])
197199
beta_params.insert(0, ols_reg.intercept_[0])
198-
self.ols_beta_params = dict(zip(self._x_design_info.column_names, beta_params))
200+
self.ols_beta_params = dict(
201+
zip(self._x_design_info.column_names, beta_params, strict=False)
202+
)
199203
self.ols_reg = ols_reg
200204

201205
def plot(self, *args, **kwargs) -> None: # type: ignore[override]

causalpy/experiments/interrupted_time_series.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
Interrupted Time Series Analysis
1616
"""
1717

18-
from typing import Any, List, Union
18+
from typing import Any
1919

2020
import arviz as az
2121
import numpy as np
@@ -91,9 +91,9 @@ class InterruptedTimeSeries(BaseExperiment):
9191
def __init__(
9292
self,
9393
data: pd.DataFrame,
94-
treatment_time: Union[int, float, pd.Timestamp],
94+
treatment_time: int | float | pd.Timestamp,
9595
formula: str,
96-
model: Union[PyMCModel, RegressorMixin] | None = None,
96+
model: PyMCModel | RegressorMixin | None = None,
9797
**kwargs: dict,
9898
) -> None:
9999
super().__init__(model=model)
@@ -155,7 +155,7 @@ def __init__(
155155
# fit the model to the observed (pre-intervention) data
156156
if isinstance(self.model, PyMCModel):
157157
is_bsts_like = isinstance(
158-
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
158+
self.model, BayesianBasisExpansionTimeSeries | StateSpaceTimeSeries
159159
)
160160

161161
if is_bsts_like:
@@ -183,7 +183,7 @@ def __init__(
183183
# score the goodness of fit to the pre-intervention data
184184
if isinstance(self.model, PyMCModel):
185185
is_bsts_like = isinstance(
186-
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
186+
self.model, BayesianBasisExpansionTimeSeries | StateSpaceTimeSeries
187187
)
188188
if is_bsts_like:
189189
X_score = self.pre_X.values if self.pre_X.shape[1] > 0 else None # type: ignore[attr-defined]
@@ -202,7 +202,7 @@ def __init__(
202202
# get the model predictions of the observed (pre-intervention) data
203203
if isinstance(self.model, PyMCModel):
204204
is_bsts_like = isinstance(
205-
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
205+
self.model, BayesianBasisExpansionTimeSeries | StateSpaceTimeSeries
206206
)
207207
if is_bsts_like:
208208
X_pre_predict = self.pre_X.values if self.pre_X.shape[1] > 0 else None # type: ignore[attr-defined]
@@ -220,7 +220,7 @@ def __init__(
220220
# calculate the counterfactual (post period)
221221
if isinstance(self.model, PyMCModel):
222222
is_bsts_like = isinstance(
223-
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
223+
self.model, BayesianBasisExpansionTimeSeries | StateSpaceTimeSeries
224224
)
225225
if is_bsts_like:
226226
X_post_predict = (
@@ -244,7 +244,7 @@ def __init__(
244244
# calculate impact - use appropriate y data format for each model type
245245
if isinstance(self.model, PyMCModel):
246246
is_bsts_like = isinstance(
247-
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
247+
self.model, BayesianBasisExpansionTimeSeries | StateSpaceTimeSeries
248248
)
249249
if is_bsts_like:
250250
pre_y_for_impact = self.pre_y.isel(treated_units=0)
@@ -275,7 +275,7 @@ def __init__(
275275
)
276276

277277
def input_validation(
278-
self, data: pd.DataFrame, treatment_time: Union[int, float, pd.Timestamp]
278+
self, data: pd.DataFrame, treatment_time: int | float | pd.Timestamp
279279
) -> None:
280280
"""Validate the input data and model formula for correctness"""
281281
if isinstance(data.index, pd.DatetimeIndex) and not isinstance(
@@ -303,7 +303,7 @@ def summary(self, round_to: int | None = None) -> None:
303303

304304
def _bayesian_plot(
305305
self, round_to: int | None = 2, **kwargs: dict
306-
) -> tuple[plt.Figure, List[plt.Axes]]:
306+
) -> tuple[plt.Figure, list[plt.Axes]]:
307307
"""
308308
Plot the results
309309
@@ -481,7 +481,7 @@ def _bayesian_plot(
481481

482482
def _ols_plot(
483483
self, round_to: int | None = 2, **kwargs: dict
484-
) -> tuple[plt.Figure, List[plt.Axes]]:
484+
) -> tuple[plt.Figure, list[plt.Axes]]:
485485
"""
486486
Plot the results
487487

causalpy/experiments/inverse_propensity_weighting.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
Inverse propensity weighting
1616
"""
1717

18-
from typing import List
19-
2018
import arviz as az
2119
import matplotlib.pyplot as plt
2220
import numpy as np
@@ -263,7 +261,7 @@ def plot_ate(
263261
method: str | None = None,
264262
prop_draws: int = 100,
265263
ate_draws: int = 300,
266-
) -> tuple[plt.Figure, List[plt.Axes]]:
264+
) -> tuple[plt.Figure, list[plt.Axes]]:
267265
if idata is None:
268266
idata = self.model.idata
269267
if method is None:
@@ -325,7 +323,7 @@ def make_hists(idata, i, axs, method=method):
325323
BBBBCC"""
326324

327325
fig, axs = plt.subplot_mosaic(mosaic, figsize=(20, 13))
328-
axs = [axs[k] for k in axs.keys()]
326+
axs = [axs[k] for k in axs]
329327
axs[0].axvline(
330328
0.1, linestyle="--", label="Low Extreme Propensity Scores", color="black"
331329
)
@@ -412,7 +410,7 @@ def plot_balance_ecdf(
412410
covariate: str,
413411
idata: az.InferenceData | None = None,
414412
weighting_scheme: str | None = None,
415-
) -> tuple[plt.Figure, List[plt.Axes]]:
413+
) -> tuple[plt.Figure, list[plt.Axes]]:
416414
"""
417415
Plotting function takes a single covariate and shows the
418416
differences in the ECDF between the treatment and control

causalpy/experiments/prepostnegd.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
Pretest/posttest nonequivalent group design
1616
"""
1717

18-
from typing import List
19-
2018
import arviz as az
2119
import numpy as np
2220
import pandas as pd
@@ -227,7 +225,7 @@ def summary(self, round_to: int | None = None) -> None:
227225

228226
def _bayesian_plot(
229227
self, round_to: int | None = None, **kwargs: dict
230-
) -> tuple[plt.Figure, List[plt.Axes]]:
228+
) -> tuple[plt.Figure, list[plt.Axes]]:
231229
"""Generate plot for ANOVA-like experiments with non-equivalent group designs."""
232230
fig, ax = plt.subplots(
233231
2, 1, figsize=(7, 9), gridspec_kw={"height_ratios": [3, 1]}

causalpy/experiments/regression_discontinuity.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
"""
1717

1818
import warnings # noqa: I001
19-
from typing import Union
2019

2120

2221
import numpy as np
@@ -88,7 +87,7 @@ def __init__(
8887
data: pd.DataFrame,
8988
formula: str,
9089
treatment_threshold: float,
91-
model: Union[PyMCModel, RegressorMixin] | None = None,
90+
model: PyMCModel | RegressorMixin | None = None,
9291
running_variable_name: str = "x",
9392
epsilon: float = 0.001,
9493
bandwidth: float = np.inf,
@@ -112,6 +111,7 @@ def __init__(
112111
warnings.warn(
113112
f"Choice of bandwidth parameter has lead to only {len(filtered_data)} remaining datapoints. Consider increasing the bandwidth parameter.", # noqa: E501
114113
UserWarning,
114+
stacklevel=2,
115115
)
116116
y, X = dmatrices(formula, filtered_data)
117117
else:
@@ -218,7 +218,7 @@ def input_validation(self) -> None:
218218
self.data = self.data.copy()
219219
self.data["treated"] = self.data["treated"].astype(bool)
220220

221-
def _is_treated(self, x: Union[np.ndarray, pd.Series]) -> np.ndarray:
221+
def _is_treated(self, x: np.ndarray | pd.Series) -> np.ndarray:
222222
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.
223223
224224
.. warning::

causalpy/experiments/regression_kink.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
"""
1818

1919
import warnings # noqa: I001
20-
from typing import Union
2120

2221

2322
from matplotlib import pyplot as plt
@@ -75,6 +74,7 @@ def __init__(
7574
warnings.warn(
7675
f"Choice of bandwidth parameter has lead to only {len(filtered_data)} remaining datapoints. Consider increasing the bandwidth parameter.", # noqa: E501
7776
UserWarning,
77+
stacklevel=2,
7878
)
7979
y, X = dmatrices(formula, filtered_data)
8080
else:
@@ -192,7 +192,7 @@ def _probe_kink_point(self) -> tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
192192
mu_kink_right = predicted["posterior_predictive"].sel(obs_ind=2)["mu"]
193193
return mu_kink_left, mu_kink, mu_kink_right
194194

195-
def _is_treated(self, x: Union[np.ndarray, pd.Series]) -> np.ndarray:
195+
def _is_treated(self, x: np.ndarray | pd.Series) -> np.ndarray:
196196
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.""" # noqa: E501
197197
return np.greater_equal(x, self.kink_point)
198198

0 commit comments

Comments
 (0)