Skip to content

Commit 9fc4eb5

Browse files
committed
ruff unsafe fixes
1 parent 421295e commit 9fc4eb5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+794
-793
lines changed

python/interpret-api/interpret/newapi/explanation.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@ def __init__(self, **kwargs):
2424
# TODO: Needs further discussion at design-level.
2525
def append(self, component):
2626
if not isinstance(component, Component):
27-
raise Exception(
28-
f"Can't append object of type {type(component)} to this object."
29-
)
27+
msg = f"Can't append object of type {type(component)} to this object."
28+
raise Exception(msg)
3029

3130
self.components[type(component)] = component
3231
for field_name, field_value in component.fields.items():
@@ -44,8 +43,8 @@ def __repr__(self):
4443
shape_str = f"shape: {self.shape}"
4544
fields.append(shape_str)
4645
fields.append("-" * len(shape_str))
47-
for record_key, record_val in record.items():
48-
for field_name, field_val in record_val.fields.items():
46+
for record_val in record.values():
47+
for field_name in record_val.fields:
4948
field_value = str(self.__getattr__(field_name))
5049

5150
if field_name in self._dims:
@@ -68,17 +67,14 @@ def __repr__(self):
6867
if len(field_value_str) > 60:
6968
field_value_str = field_value_str[:57] + "..."
7069
fields.append(field_value_str)
71-
fields = "\n".join(fields)
72-
73-
return fields
70+
return "\n".join(fields)
7471

7572
@classmethod
7673
def from_json(cls, json_str):
7774
from interpret.newapi.serialization import ExplanationJSONDecoder
7875

7976
d = json.loads(json_str, cls=ExplanationJSONDecoder)
80-
instance = d["content"]
81-
return instance
77+
return d["content"]
8278

8379
@classmethod
8480
def from_components(cls, components):

python/interpret-core/interpret/api/templates.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,20 @@ def visualize(self, key=None, title=None):
126126
perf=perf[key],
127127
)
128128
# pragma: no cover
129-
raise RuntimeError(f"Visual provider {provider} not supported")
129+
msg = f"Visual provider {provider} not supported"
130+
raise RuntimeError(msg)
130131
is_multiclass = is_multiclass_local_data_dict(data_dict)
131132
if is_multiclass:
132133
# Sort by predicted class' abs feature values
133134
pred_idx = data_dict["perf"]["predicted"]
134-
sort_fn = lambda x: -abs(x[pred_idx])
135+
136+
def sort_fn(x):
137+
return -abs(x[pred_idx])
135138
else:
136139
# Sort by abs feature values
137-
sort_fn = lambda x: -abs(x)
140+
def sort_fn(x):
141+
return -abs(x)
142+
138143
data_dict = sort_take(
139144
data_dict, sort_fn=sort_fn, top_n=15, reverse_results=True
140145
)
@@ -146,7 +151,7 @@ def visualize(self, key=None, title=None):
146151
title = self.feature_names[key]
147152
if feature_type == "continuous":
148153
return plot_line(data_dict, title=title)
149-
if feature_type == "nominal" or feature_type == "ordinal":
154+
if feature_type in ("nominal", "ordinal"):
150155
return plot_bar(data_dict, title=title)
151156
if feature_type == "interaction":
152157
# TODO: Generalize this out.
@@ -157,6 +162,7 @@ def visualize(self, key=None, title=None):
157162
)
158163

159164
# Handle everything else as invalid
165+
msg = f"Not supported configuration: {self.explanation_type}, {feature_type}"
160166
raise Exception( # pragma: no cover
161-
f"Not supported configuration: {self.explanation_type}, {feature_type}"
167+
msg
162168
)

python/interpret-core/interpret/blackbox/_lime.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def explain_local(self, X, y=None, name=None, **kwargs):
7676
if y is not None:
7777
y = clean_dimensions(y, "y")
7878
if y.ndim != 1:
79-
raise ValueError("y must be 1 dimensional")
79+
msg = "y must be 1 dimensional"
80+
raise ValueError(msg)
8081
n_samples = len(y)
8182

8283
X, n_samples = preclean_X(
@@ -85,7 +86,8 @@ def explain_local(self, X, y=None, name=None, **kwargs):
8586

8687
predict_fn, n_classes, classes = determine_classes(self.model, X, n_samples)
8788
if n_classes >= 3:
88-
raise Exception("multiclass LIME not supported")
89+
msg = "multiclass LIME not supported"
90+
raise Exception(msg)
8991
predict_fn = unify_predict_fn(predict_fn, X, 1 if n_classes == 2 else -1)
9092

9193
X, _, _ = unify_data(

python/interpret-core/interpret/blackbox/_partialdependence.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ def _unique_grid_points(values):
1818

1919
def _percentile_grid_points(values, num_points=10):
2020
percentiles = np.linspace(0, 100, num=num_points)
21-
grid_points = np.percentile(values, percentiles)
22-
return grid_points
21+
return np.percentile(values, percentiles)
2322

2423

2524
# def _equal_spaced_grid_points(values, num_points=10):
@@ -37,11 +36,7 @@ def _gen_pdp(
3736
num_ice_samples=10,
3837
):
3938
num_uniq_vals = len(np.unique(X[:, col_idx]))
40-
if (
41-
feature_type == "nominal"
42-
or feature_type == "ordinal"
43-
or num_uniq_vals <= num_points
44-
):
39+
if feature_type in ("nominal", "ordinal") or num_uniq_vals <= num_points:
4540
grid_points = _unique_grid_points(X[:, col_idx])
4641
values, counts = np.unique(X[:, col_idx], return_counts=True)
4742
else:
@@ -109,7 +104,8 @@ def __init__(
109104

110105
predict_fn, n_classes, _ = determine_classes(model, data, n_samples)
111106
if n_classes >= 3:
112-
raise Exception("multiclass PDP not supported")
107+
msg = "multiclass PDP not supported"
108+
raise Exception(msg)
113109
predict_fn = unify_predict_fn(predict_fn, data, 1 if n_classes == 2 else -1)
114110

115111
data, self.feature_names_in_, self.feature_types_in_ = unify_data(
@@ -121,7 +117,7 @@ def __init__(
121117

122118
pdps = []
123119
unique_val_counts = np.zeros(len(self.feature_names_in_), dtype=np.int64)
124-
for col_idx, feature in enumerate(self.feature_names_in_):
120+
for col_idx, _feature in enumerate(self.feature_names_in_):
125121
feature_type = self.feature_types_in_[col_idx]
126122
pdp = _gen_pdp(
127123
data,
@@ -156,7 +152,7 @@ def explain_global(self, name=None):
156152
data_dicts = []
157153
feature_list = []
158154
density_list = []
159-
for col_idx, feature in enumerate(self.feature_names_in_):
155+
for col_idx, _feature in enumerate(self.feature_names_in_):
160156
pdp = self.pdps_[col_idx]
161157
feature_dict = {
162158
"feature_values": pdp["values"],
@@ -259,10 +255,11 @@ def visualize(self, key=None):
259255
feature_name = self.feature_names[key]
260256
if feature_type == "continuous":
261257
figure = plot_line(data_dict, title=feature_name)
262-
elif feature_type == "nominal" or feature_type == "ordinal":
258+
elif feature_type in ("nominal", "ordinal"):
263259
figure = plot_bar(data_dict, title=feature_name)
264260
else:
265-
raise Exception(f"Feature type {feature_type} is not supported.")
261+
msg = f"Feature type {feature_type} is not supported."
262+
raise Exception(msg)
266263

267264
figure["layout"]["yaxis1"].update(title="Average Response")
268265
return figure

python/interpret-core/interpret/blackbox/_sensitivity.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ def __init__(
7878

7979
predict_fn, n_classes, _ = determine_classes(model, data, n_samples)
8080
if n_classes >= 3:
81-
raise Exception("multiclass MorrisSensitivity not supported")
81+
msg = "multiclass MorrisSensitivity not supported"
82+
raise Exception(msg)
8283
predict_fn = unify_predict_fn(predict_fn, data, 1 if n_classes == 2 else -1)
8384

8485
data, self.feature_names_in_, self.feature_types_in_ = unify_data(
@@ -112,7 +113,7 @@ def __init__(
112113
)
113114

114115
unique_val_counts = np.zeros(len(self.feature_names_in_), dtype=np.int64)
115-
for col_idx, feature in enumerate(self.feature_names_in_):
116+
for col_idx, _feature in enumerate(self.feature_names_in_):
116117
X_col = data[:, col_idx]
117118
unique_val_counts[col_idx] = len(np.unique(X_col))
118119

@@ -138,7 +139,7 @@ def explain_global(self, name=None):
138139
}
139140

140141
specific_data_dicts = []
141-
for feat_idx, feature_name in enumerate(self.feature_names_in_):
142+
for feat_idx, _feature_name in enumerate(self.feature_names_in_):
142143
specific_data_dict = {
143144
"type": "morris",
144145
"mu": self.mu_[feat_idx],
@@ -193,7 +194,7 @@ def __init__(
193194
selector: A dataframe whose indices correspond to explanation entries.
194195
"""
195196

196-
super(MorrisExplanation, self).__init__(
197+
super().__init__(
197198
explanation_type,
198199
internal_obj,
199200
feature_names=feature_names,
@@ -223,11 +224,10 @@ def visualize(self, key=None):
223224
data_dict = sort_take(
224225
data_dict, sort_fn=lambda x: -abs(x), top_n=15, reverse_results=True
225226
)
226-
title = "Morris Sensitivity<br>Convergence Index: {0:.3f}".format(
227+
title = "Morris Sensitivity<br>Convergence Index: {:.3f}".format(
227228
data_dict["convergence_index"]
228229
)
229-
figure = plot_horizontal_bar(data_dict, start_zero=True, title=title)
230-
return figure
230+
return plot_horizontal_bar(data_dict, start_zero=True, title=title)
231231

232232
if self.explanation_type == "global" and key is not None:
233233
multi_html_template = r"""
@@ -282,10 +282,9 @@ def visualize(self, key=None):
282282
mu_star_conf=data_dict["mu_star_conf"],
283283
)
284284

285-
html_str = multi_html_template.format(
285+
return multi_html_template.format(
286286
feature_name=self.feature_names[key], analyses=analysis
287287
)
288-
return html_str
289288

290289
return super().visualize(key)
291290

@@ -312,9 +311,8 @@ def _soft_min_max(values, soft_add=1, soft_bounds=1):
312311

313312
def _gen_problem_from_data(data, feature_names):
314313
bounds = [_soft_min_max(data[:, i]) for i, _ in enumerate(feature_names)]
315-
problem = {
314+
return {
316315
"num_vars": len(feature_names),
317316
"names": feature_names,
318317
"bounds": bounds,
319318
}
320-
return problem

python/interpret-core/interpret/blackbox/_shap.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ def __init__(self, model, data, feature_names=None, feature_types=None, **kwargs
3838

3939
predict_fn, n_classes, _ = determine_classes(model, data, n_samples)
4040
if n_classes >= 3:
41-
raise Exception("multiclass SHAP not supported")
41+
msg = "multiclass SHAP not supported"
42+
raise Exception(msg)
4243
predict_fn = unify_predict_fn(predict_fn, data, 1 if n_classes == 2 else -1)
4344

4445
data, self.feature_names_in_, self.feature_types_in_ = unify_data(

0 commit comments

Comments
 (0)