Add encoder #121

Draft: wants to merge 9 commits into main
21 changes: 14 additions & 7 deletions hiclass/HierarchicalClassifier.py
@@ -3,13 +3,13 @@
import abc
import hashlib
import logging
import pickle

import networkx as nx
import numpy as np
import pickle
from joblib import Parallel, delayed
from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.validation import _check_sample_weight

try:
@@ -215,7 +215,14 @@ def _disambiguate(self):
child = str(self.y_[i, j])
row.append(parent + self.separator_ + child)
new_y.append(np.asarray(row, dtype=np.str_))
self.y_ = np.array(new_y)
new_y = np.array(new_y)
flat_y = np.unique(np.append(new_y.flatten(), "hiclass::root"))
if not self.bert:
self.label_encoder_ = LabelEncoder()
self.label_encoder_.fit(flat_y)
self.y_ = np.array(
[self.label_encoder_.transform(row) for row in new_y]
)

def _create_digraph(self):
# Create DiGraph
@@ -255,8 +262,8 @@ def _create_digraph_2d(self):
self.logger_.info(f"Creating digraph from {rows} 2D labels")
for row in range(rows):
for column in range(columns - 1):
parent = self.y_[row, column].split(self.separator_)[-1]
child = self.y_[row, column + 1].split(self.separator_)[-1]
parent = self.y_[row, column]
child = self.y_[row, column + 1]
if parent != "" and child != "":
# Only add edge if both parent and child are not empty
self.hierarchy_.add_edge(
@@ -271,7 +278,7 @@ def _export_digraph(self):
# Add quotes to all nodes in case the text has commas
mapping = {}
for node in self.hierarchy_:
mapping[node] = '"{}"'.format(node.split(self.separator_)[-1])
mapping[node] = '"{}"'.format(node)
hierarchy = nx.relabel_nodes(self.hierarchy_, mapping, copy=True)
# Export DAG to CSV file
self.logger_.info(f"Writing edge list to file {self.edge_list}")
@@ -371,5 +378,5 @@ def _save_tmp(self, name, classifier):
with open(filename, "wb") as file:
pickle.dump((name, classifier), file)
self.logger_.info(
f"Stored trained model for local classifier {str(name).split(self.separator_)[-1]} in file {filename}"
f"Stored trained model for local classifier {str(name)} in file {filename}"
)
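
The hunks above introduce a label-encoding step in `_disambiguate`: all disambiguated labels plus the artificial `hiclass::root` node are fitted into a single `LabelEncoder`, and `y_` is stored as integers from then on. A minimal sketch of that pattern, with made-up labels (the `::HiClass::Separator::` string mirrors the separator used in the tests below):

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

# Disambiguated labels, as produced earlier in _disambiguate (illustrative values).
new_y = np.array(
    [
        ["1", "1::HiClass::Separator::1.1"],
        ["2", "2::HiClass::Separator::2.1"],
    ]
)

# One encoder over every label plus the artificial root, so the root also gets an id.
flat_y = np.unique(np.append(new_y.flatten(), "hiclass::root"))
label_encoder = LabelEncoder()
label_encoder.fit(flat_y)

# Encode row by row; y_ now holds integer ids instead of strings.
encoded_y = np.array([label_encoder.transform(row) for row in new_y])
print(encoded_y)                                      # [[0 1] [2 3]]
print(label_encoder.inverse_transform(encoded_y[0]))  # ['1' '1::HiClass::Separator::1.1']
```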
8 changes: 7 additions & 1 deletion hiclass/LocalClassifierPerLevel.py
@@ -168,6 +168,9 @@ def predict(self, X):

y = self._convert_to_1d(y)

if hasattr(self, "label_encoder_"):
y = np.array([self.label_encoder_.inverse_transform(row) for row in y])

self._remove_separator(y)

return y
@@ -218,7 +221,8 @@ def _get_successors(self, level):
def _initialize_local_classifiers(self):
super()._initialize_local_classifiers()
self.local_classifiers_ = [
deepcopy(self.local_classifier_) for _ in range(self.y_.shape[1])
MulticlassClassifier(deepcopy(self.local_classifier_), strategy="ovr")
for _ in range(self.y_.shape[1])
]
self.masks_ = [None for _ in range(self.y_.shape[1])]

@@ -272,6 +276,8 @@ def _fit_classifier(self, level, separator):
if len(unique_y) == 1 and self.replace_classifiers:
classifier = ConstantClassifier()
if not self.bert:
self.logger_.info(X)
self.logger_.info(y)
try:
classifier.fit(X, y, sample_weight)
except TypeError:
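
Two things change in this file: `predict` decodes the integer predictions back to the original strings via `inverse_transform`, and each per-level classifier is wrapped in cuML's `MulticlassClassifier` with a one-vs-rest strategy. A rough CPU analogue of the wrapping, using scikit-learn's `OneVsRestClassifier` as a stand-in for the cuML class (data and estimator are made up):

```python
from copy import deepcopy

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

base = LogisticRegression()
n_levels = 3

# One one-vs-rest wrapper per hierarchy level, each with its own copy of the
# base estimator, mirroring _initialize_local_classifiers above.
local_classifiers = [OneVsRestClassifier(deepcopy(base)) for _ in range(n_levels)]

X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
y_level_0 = np.array([0, 0, 1, 1])  # integer labels produced by the encoder
local_classifiers[0].fit(X, y_level_0)
print(local_classifiers[0].predict(X))  # [0 0 1 1]
```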
9 changes: 6 additions & 3 deletions hiclass/LocalClassifierPerNode.py
@@ -182,7 +182,7 @@ def predict(self, X):
if subset_x.shape[0] > 0:
probabilities = np.zeros((subset_x.shape[0], len(successors)))
for i, successor in enumerate(successors):
successor_name = str(successor).split(self.separator_)[-1]
successor_name = str(successor)
self.logger_.info(f"Predicting for node '{successor_name}'")
classifier = self.hierarchy_.nodes[successor]["classifier"]
positive_index = np.where(classifier.classes_ == 1)[0]
@@ -201,6 +201,9 @@

y = self._convert_to_1d(y)

if hasattr(self, "label_encoder_"):
y = np.array([self.label_encoder_.inverse_transform(row) for row in y])

self._remove_separator(y)

return y
@@ -246,12 +249,12 @@ def _fit_digraph(self, local_mode: bool = False, use_joblib: bool = False):
def _fit_classifier(self, node):
classifier = self.hierarchy_.nodes[node]["classifier"]
if self.tmp_dir:
md5 = hashlib.md5(node.encode("utf-8")).hexdigest()
md5 = hashlib.md5(str(node).encode("utf-8")).hexdigest()
filename = f"{self.tmp_dir}/{md5}.sav"
if exists(filename):
(_, classifier) = pickle.load(open(filename, "rb"))
self.logger_.info(
f"Loaded trained model for local classifier {node.split(self.separator_)[-1]} from file {filename}"
f"Loaded trained model for local classifier {node} from file {filename}"
)
return classifier
self.logger_.info(f"Training local classifier {node}")
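
Node identifiers can now be integers (the encoded labels), so the tmp-dir caching hashes `str(node)` instead of assuming a string. The hypothetical helper below condenses that load-or-train pattern from `_fit_classifier`; `load_or_train` itself is not part of the PR, only an illustration:

```python
import hashlib
import pickle
from os.path import exists


def load_or_train(node, classifier, X, y, tmp_dir):
    """Reuse a pickled model for this node if one was already stored, else fit it."""
    md5 = hashlib.md5(str(node).encode("utf-8")).hexdigest()
    filename = f"{tmp_dir}/{md5}.sav"
    if exists(filename):
        with open(filename, "rb") as file:
            _, classifier = pickle.load(file)
        return classifier
    classifier.fit(X, y)
    with open(filename, "wb") as file:
        pickle.dump((node, classifier), file)
    return classifier
```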
18 changes: 10 additions & 8 deletions hiclass/LocalClassifierPerParentNode.py
@@ -5,12 +5,13 @@
"""

import hashlib
import networkx as nx
import numpy as np
import pickle
from copy import deepcopy
from cuml.common.device_selection import using_device_type
from cuml.multiclass import MulticlassClassifier
from os.path import exists

import networkx as nx
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_array, check_is_fitted

Expand Down Expand Up @@ -161,6 +162,9 @@ def predict(self, X):

y = self._convert_to_1d(y)

if hasattr(self, "label_encoder_"):
y = np.array([self.label_encoder_.inverse_transform(row) for row in y])

self._remove_separator(y)

return y
@@ -215,12 +219,12 @@ def _get_successors(self, node):
def _fit_classifier(self, node):
classifier = self.hierarchy_.nodes[node]["classifier"]
if self.tmp_dir:
md5 = hashlib.md5(node.encode("utf-8")).hexdigest()
md5 = hashlib.md5(str(node).encode("utf-8")).hexdigest()
filename = f"{self.tmp_dir}/{md5}.sav"
if exists(filename):
(_, classifier) = pickle.load(open(filename, "rb"))
self.logger_.info(
f"Loaded trained model for local classifier {node.split(self.separator_)[-1]} from file {filename}"
f"Loaded trained model for local classifier {node} from file {filename}"
)
return classifier
self.logger_.info(f"Training local classifier {node}")
@@ -230,9 +234,7 @@ def _fit_classifier(self, node):
if len(unique_y) == 1 and self.replace_classifiers:
classifier = ConstantClassifier()
if not self.bert:
try:
classifier.fit(X, y, sample_weight)
except TypeError:
with using_device_type("gpu"):
classifier.fit(X, y)
else:
classifier.fit(X, y)
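
Here the fallback `try/except TypeError` around `fit` is replaced by running the fit inside cuML's `using_device_type("gpu")` context (imported at the top of this file in the diff). A minimal sketch of that pattern, assuming cuML is installed and a GPU is available; `LogisticRegression` stands in for whatever local classifier is configured:

```python
import numpy as np
from cuml.common.device_selection import using_device_type
from cuml.linear_model import LogisticRegression

X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], dtype=np.float32)
y = np.array([0, 0, 1, 1], dtype=np.int32)

clf = LogisticRegression()
# Force execution on the GPU for the duration of the fit, as in _fit_classifier.
with using_device_type("gpu"):
    clf.fit(X, y)
print(clf.predict(X))
```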
8 changes: 8 additions & 0 deletions tests/test_Explainer.py
@@ -52,6 +52,7 @@ def explainer_data_no_root():


@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
def test_explainer_tree_lcppn(data, request):
rfc = RandomForestClassifier()
@@ -104,6 +105,7 @@ def test_explainer_tree_lcpn(data, request):


@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
def test_explainer_tree_lcpl(data, request):
rfc = RandomForestClassifier()
@@ -124,6 +126,7 @@ def test_explainer_tree_lcpl(data, request):


@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
def test_traversal_path_lcppn(data, request):
x_train, x_test, y_train = request.getfixturevalue(data)
@@ -146,6 +149,7 @@ def test_traversal_path_lcppn(data, request):


@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
def test_traversal_path_lcpn(data, request):
x_train, x_test, y_train = request.getfixturevalue(data)
@@ -168,6 +172,7 @@ def test_traversal_path_lcpn(data, request):


@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
def test_traversal_path_lcpl(data, request):
x_train, x_test, y_train = request.getfixturevalue(data)
@@ -205,6 +210,8 @@ def test_explain_with_xr(data, request, classifier):
assert isinstance(explanations, xarray.Dataset)


@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
@pytest.mark.parametrize(
"classifier",
[LocalClassifierPerParentNode, LocalClassifierPerLevel, LocalClassifierPerNode],
@@ -222,6 +229,7 @@ def test_imports(classifier):


@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
@pytest.mark.parametrize(
"classifier",
[LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode],
6 changes: 6 additions & 0 deletions tests/test_HierarchicalClassifier.py
@@ -22,6 +22,9 @@ def test_disambiguate_str(ambiguous_node_str):
[["a", "a::HiClass::Separator::b"], ["b", "b::HiClass::Separator::c"]]
)
ambiguous_node_str._disambiguate()
ground_truth = np.array(
[ambiguous_node_str.label_encoder_.transform(row) for row in ground_truth]
)
assert_array_equal(ground_truth, ambiguous_node_str.y_)


@@ -37,6 +40,9 @@ def test_disambiguate_int(ambiguous_node_int):
[["1", "1::HiClass::Separator::2"], ["2", "2::HiClass::Separator::3"]]
)
ambiguous_node_int._disambiguate()
ground_truth = np.array(
[ambiguous_node_int.label_encoder_.transform(row) for row in ground_truth]
)
assert_array_equal(ground_truth, ambiguous_node_int.y_)


4 changes: 3 additions & 1 deletion tests/test_LocalClassifierPerLevel.py
@@ -128,7 +128,9 @@ def test_fit_predict():
for level, classifier in enumerate(lcpl.local_classifiers_):
try:
check_is_fitted(classifier)
assert_array_equal(ground_truth[level], classifier.classes_)
assert_array_equal(
lcpl.label_encoder_.transform(ground_truth[level]), classifier.classes_
)
except NotFittedError as e:
pytest.fail(repr(e))
predictions = lcpl.predict(x)
18 changes: 3 additions & 15 deletions tests/test_LocalClassifiers.py
@@ -63,15 +63,7 @@ def test_empty_levels(empty_levels, classifier):
["2", "2.1", ""],
["3", "3.1", "3.1.2"],
]
assert list(clf.hierarchy_.nodes) == [
"1",
"2",
"2" + clf.separator_ + "2.1",
"3",
"3" + clf.separator_ + "3.1",
"3" + clf.separator_ + "3.1" + clf.separator_ + "3.1.2",
clf.root_,
]
assert list(clf.hierarchy_.nodes) == [0, 1, 2, 3, 4, 5, 6, 7, 8, "hiclass::root"]
assert_array_equal(ground_truth, predictions)


@@ -132,12 +124,8 @@ def test_tmp_dir(classifier):
x = np.array([[1, 2], [3, 4]])
y = np.array([["a", "b"], ["c", "d"]])
clf.fit(x, y)
if isinstance(clf, LocalClassifierPerLevel):
filename = "cfcd208495d565ef66e7dff9f98764da.sav"
expected_name = 0
else:
filename = "0cc175b9c0f1b6a831c399e269772661.sav"
expected_name = "a"
filename = "cfcd208495d565ef66e7dff9f98764da.sav"
expected_name = 0
assert patcher.fs.exists(filename)
(name, classifier) = pickle.load(open(filename, "rb"))
assert expected_name == name
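
With the labels encoded, the first classifier saved to the tmp dir is keyed by the integer 0 for all three classifier variants, so the per-variant branch collapses to a single expected filename. A quick check of the hash behind that constant:

```python
import hashlib

# str(0) is hashed, matching the single filename asserted in test_tmp_dir.
print(hashlib.md5("0".encode("utf-8")).hexdigest())  # cfcd208495d565ef66e7dff9f98764da
```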