diff --git a/hiclass/HierarchicalClassifier.py b/hiclass/HierarchicalClassifier.py
index 23e422ab..b7c0e738 100644
--- a/hiclass/HierarchicalClassifier.py
+++ b/hiclass/HierarchicalClassifier.py
@@ -3,13 +3,13 @@
 import abc
 import hashlib
 import logging
-import pickle
-
 import networkx as nx
 import numpy as np
+import pickle
 from joblib import Parallel, delayed
 from sklearn.base import BaseEstimator
 from sklearn.linear_model import LogisticRegression
+from sklearn.preprocessing import LabelEncoder
 from sklearn.utils.validation import _check_sample_weight
 
 try:
@@ -215,7 +215,14 @@ def _disambiguate(self):
                     child = str(self.y_[i, j])
                     row.append(parent + self.separator_ + child)
                 new_y.append(np.asarray(row, dtype=np.str_))
-            self.y_ = np.array(new_y)
+            new_y = np.array(new_y)
+            flat_y = np.unique(np.append(new_y.flatten(), "hiclass::root"))
+            if not self.bert:
+                self.label_encoder_ = LabelEncoder()
+                self.label_encoder_.fit(flat_y)
+            self.y_ = np.array(
+                [self.label_encoder_.transform(row) for row in new_y]
+            )
 
     def _create_digraph(self):
         # Create DiGraph
@@ -255,8 +262,8 @@ def _create_digraph_2d(self):
             self.logger_.info(f"Creating digraph from {rows} 2D labels")
             for row in range(rows):
                 for column in range(columns - 1):
-                    parent = self.y_[row, column].split(self.separator_)[-1]
-                    child = self.y_[row, column + 1].split(self.separator_)[-1]
+                    parent = self.y_[row, column]
+                    child = self.y_[row, column + 1]
                     if parent != "" and child != "":
                         # Only add edge if both parent and child are not empty
                         self.hierarchy_.add_edge(
@@ -271,7 +278,7 @@ def _export_digraph(self):
         # Add quotes to all nodes in case the text has commas
         mapping = {}
         for node in self.hierarchy_:
-            mapping[node] = '"{}"'.format(node.split(self.separator_)[-1])
+            mapping[node] = '"{}"'.format(node)
         hierarchy = nx.relabel_nodes(self.hierarchy_, mapping, copy=True)
         # Export DAG to CSV file
         self.logger_.info(f"Writing edge list to file {self.edge_list}")
@@ -371,5 +378,5 @@ def _save_tmp(self, name, classifier):
             with open(filename, "wb") as file:
                 pickle.dump((name, classifier), file)
                 self.logger_.info(
-                    f"Stored trained model for local classifier {str(name).split(self.separator_)[-1]} in file {filename}"
+                    f"Stored trained model for local classifier {str(name)} in file {filename}"
                 )
diff --git a/hiclass/LocalClassifierPerLevel.py b/hiclass/LocalClassifierPerLevel.py
index 907e61cf..9cebadaa 100644
--- a/hiclass/LocalClassifierPerLevel.py
+++ b/hiclass/LocalClassifierPerLevel.py
@@ -168,6 +168,9 @@ def predict(self, X):
 
         y = self._convert_to_1d(y)
 
+        if hasattr(self, "label_encoder_"):
+            y = np.array([self.label_encoder_.inverse_transform(row) for row in y])
+
         self._remove_separator(y)
 
         return y
@@ -218,7 +221,8 @@ def _get_successors(self, level):
 
     def _initialize_local_classifiers(self):
         super()._initialize_local_classifiers()
         self.local_classifiers_ = [
-            deepcopy(self.local_classifier_) for _ in range(self.y_.shape[1])
+            MulticlassClassifier(deepcopy(self.local_classifier_), strategy="ovr")
+            for _ in range(self.y_.shape[1])
         ]
         self.masks_ = [None for _ in range(self.y_.shape[1])]
@@ -272,6 +276,8 @@ def _fit_classifier(self, level, separator):
         if len(unique_y) == 1 and self.replace_classifiers:
             classifier = ConstantClassifier()
         if not self.bert:
+            self.logger_.info(X)
+            self.logger_.info(y)
             try:
                 classifier.fit(X, y, sample_weight)
             except TypeError:
diff --git a/hiclass/LocalClassifierPerNode.py b/hiclass/LocalClassifierPerNode.py
index 1382c72e..42f7a648 100644
--- a/hiclass/LocalClassifierPerNode.py
+++ b/hiclass/LocalClassifierPerNode.py
@@ -182,7 +182,7 @@ def predict(self, X):
             if subset_x.shape[0] > 0:
                 probabilities = np.zeros((subset_x.shape[0], len(successors)))
                 for i, successor in enumerate(successors):
-                    successor_name = str(successor).split(self.separator_)[-1]
+                    successor_name = str(successor)
                     self.logger_.info(f"Predicting for node '{successor_name}'")
                     classifier = self.hierarchy_.nodes[successor]["classifier"]
                     positive_index = np.where(classifier.classes_ == 1)[0]
@@ -201,6 +201,9 @@ def predict(self, X):
 
         y = self._convert_to_1d(y)
 
+        if hasattr(self, "label_encoder_"):
+            y = np.array([self.label_encoder_.inverse_transform(row) for row in y])
+
         self._remove_separator(y)
 
         return y
@@ -246,12 +249,12 @@ def _fit_digraph(self, local_mode: bool = False, use_joblib: bool = False):
     def _fit_classifier(self, node):
         classifier = self.hierarchy_.nodes[node]["classifier"]
         if self.tmp_dir:
-            md5 = hashlib.md5(node.encode("utf-8")).hexdigest()
+            md5 = hashlib.md5(str(node).encode("utf-8")).hexdigest()
             filename = f"{self.tmp_dir}/{md5}.sav"
             if exists(filename):
                 (_, classifier) = pickle.load(open(filename, "rb"))
                 self.logger_.info(
-                    f"Loaded trained model for local classifier {node.split(self.separator_)[-1]} from file {filename}"
+                    f"Loaded trained model for local classifier {node} from file {filename}"
                 )
                 return classifier
         self.logger_.info(f"Training local classifier {node}")
diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py
index 47f77475..089ce0dc 100644
--- a/hiclass/LocalClassifierPerParentNode.py
+++ b/hiclass/LocalClassifierPerParentNode.py
@@ -5,12 +5,13 @@
 """
 
 import hashlib
+import networkx as nx
+import numpy as np
 import pickle
 from copy import deepcopy
+from cuml.common.device_selection import using_device_type
+from cuml.multiclass import MulticlassClassifier
 from os.path import exists
-
-import networkx as nx
-import numpy as np
 from sklearn.base import BaseEstimator
 from sklearn.utils.validation import check_array, check_is_fitted
 
@@ -161,6 +162,9 @@ def predict(self, X):
 
         y = self._convert_to_1d(y)
 
+        if hasattr(self, "label_encoder_"):
+            y = np.array([self.label_encoder_.inverse_transform(row) for row in y])
+
         self._remove_separator(y)
 
         return y
@@ -215,12 +219,12 @@ def _get_successors(self, node):
     def _fit_classifier(self, node):
         classifier = self.hierarchy_.nodes[node]["classifier"]
         if self.tmp_dir:
-            md5 = hashlib.md5(node.encode("utf-8")).hexdigest()
+            md5 = hashlib.md5(str(node).encode("utf-8")).hexdigest()
             filename = f"{self.tmp_dir}/{md5}.sav"
             if exists(filename):
                 (_, classifier) = pickle.load(open(filename, "rb"))
                 self.logger_.info(
-                    f"Loaded trained model for local classifier {node.split(self.separator_)[-1]} from file {filename}"
+                    f"Loaded trained model for local classifier {node} from file {filename}"
                 )
                 return classifier
         self.logger_.info(f"Training local classifier {node}")
@@ -230,9 +234,7 @@ def _fit_classifier(self, node):
         if len(unique_y) == 1 and self.replace_classifiers:
             classifier = ConstantClassifier()
         if not self.bert:
-            try:
-                classifier.fit(X, y, sample_weight)
-            except TypeError:
+            with using_device_type("gpu"):
                 classifier.fit(X, y)
         else:
             classifier.fit(X, y)
diff --git a/tests/test_Explainer.py b/tests/test_Explainer.py
index 303216f6..a466e69b 100644
--- a/tests/test_Explainer.py
+++ b/tests/test_Explainer.py
@@ -52,6 +52,7 @@ def explainer_data_no_root():
 
 
 @pytest.mark.skipif(not shap_installed, reason="shap not installed")
+@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
 @pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
 def test_explainer_tree_lcppn(data, request):
     rfc = RandomForestClassifier()
@@ -104,6 +105,7 @@ def test_explainer_tree_lcpn(data, request):
 
 
 @pytest.mark.skipif(not shap_installed, reason="shap not installed")
+@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
 @pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
 def test_explainer_tree_lcpl(data, request):
     rfc = RandomForestClassifier()
@@ -124,6 +126,7 @@ def test_explainer_tree_lcpl(data, request):
 
 
 @pytest.mark.skipif(not shap_installed, reason="shap not installed")
+@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
 @pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
 def test_traversal_path_lcppn(data, request):
     x_train, x_test, y_train = request.getfixturevalue(data)
@@ -146,6 +149,7 @@ def test_traversal_path_lcppn(data, request):
 
 
 @pytest.mark.skipif(not shap_installed, reason="shap not installed")
+@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
 @pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
 def test_traversal_path_lcpn(data, request):
     x_train, x_test, y_train = request.getfixturevalue(data)
@@ -168,6 +172,7 @@ def test_traversal_path_lcpn(data, request):
 
 
 @pytest.mark.skipif(not shap_installed, reason="shap not installed")
+@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
 @pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
 def test_traversal_path_lcpl(data, request):
     x_train, x_test, y_train = request.getfixturevalue(data)
@@ -205,6 +210,8 @@ def test_explain_with_xr(data, request, classifier):
     assert isinstance(explanations, xarray.Dataset)
 
 
+@pytest.mark.skipif(not shap_installed, reason="shap not installed")
+@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
 @pytest.mark.parametrize(
     "classifier",
     [LocalClassifierPerParentNode, LocalClassifierPerLevel, LocalClassifierPerNode],
@@ -222,6 +229,7 @@ def test_imports(classifier):
 
 
 @pytest.mark.skipif(not shap_installed, reason="shap not installed")
+@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
 @pytest.mark.parametrize(
     "classifier",
     [LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode],
diff --git a/tests/test_HierarchicalClassifier.py b/tests/test_HierarchicalClassifier.py
index 3333cf52..cb608f05 100644
--- a/tests/test_HierarchicalClassifier.py
+++ b/tests/test_HierarchicalClassifier.py
@@ -22,6 +22,9 @@ def test_disambiguate_str(ambiguous_node_str):
         [["a", "a::HiClass::Separator::b"], ["b", "b::HiClass::Separator::c"]]
     )
     ambiguous_node_str._disambiguate()
+    ground_truth = np.array(
+        [ambiguous_node_str.label_encoder_.transform(row) for row in ground_truth]
+    )
     assert_array_equal(ground_truth, ambiguous_node_str.y_)
 
 
@@ -37,6 +40,9 @@ def test_disambiguate_int(ambiguous_node_int):
         [["1", "1::HiClass::Separator::2"], ["2", "2::HiClass::Separator::3"]]
    )
     ambiguous_node_int._disambiguate()
+    ground_truth = np.array(
+        [ambiguous_node_int.label_encoder_.transform(row) for row in ground_truth]
+    )
     assert_array_equal(ground_truth, ambiguous_node_int.y_)
 
 
diff --git a/tests/test_LocalClassifierPerLevel.py b/tests/test_LocalClassifierPerLevel.py
index 27312f85..ee39e90b 100644
--- a/tests/test_LocalClassifierPerLevel.py
+++ b/tests/test_LocalClassifierPerLevel.py
@@ -128,7 +128,9 @@ def test_fit_predict():
     for level, classifier in enumerate(lcpl.local_classifiers_):
         try:
             check_is_fitted(classifier)
-            assert_array_equal(ground_truth[level], classifier.classes_)
+            assert_array_equal(
+                lcpl.label_encoder_.transform(ground_truth[level]), classifier.classes_
+            )
         except NotFittedError as e:
             pytest.fail(repr(e))
     predictions = lcpl.predict(x)
diff --git a/tests/test_LocalClassifiers.py b/tests/test_LocalClassifiers.py
index abd7bddf..b47e88b4 100644
--- a/tests/test_LocalClassifiers.py
+++ b/tests/test_LocalClassifiers.py
@@ -63,15 +63,7 @@ def test_empty_levels(empty_levels, classifier):
         ["2", "2.1", ""],
         ["3", "3.1", "3.1.2"],
     ]
-    assert list(clf.hierarchy_.nodes) == [
-        "1",
-        "2",
-        "2" + clf.separator_ + "2.1",
-        "3",
-        "3" + clf.separator_ + "3.1",
-        "3" + clf.separator_ + "3.1" + clf.separator_ + "3.1.2",
-        clf.root_,
-    ]
+    assert list(clf.hierarchy_.nodes) == [0, 1, 2, 3, 4, 5, 6, 7, 8, "hiclass::root"]
     assert_array_equal(ground_truth, predictions)
 
 
@@ -132,12 +124,8 @@ def test_tmp_dir(classifier):
     x = np.array([[1, 2], [3, 4]])
     y = np.array([["a", "b"], ["c", "d"]])
     clf.fit(x, y)
-    if isinstance(clf, LocalClassifierPerLevel):
-        filename = "cfcd208495d565ef66e7dff9f98764da.sav"
-        expected_name = 0
-    else:
-        filename = "0cc175b9c0f1b6a831c399e269772661.sav"
-        expected_name = "a"
+    filename = "cfcd208495d565ef66e7dff9f98764da.sav"
+    expected_name = 0
     assert patcher.fs.exists(filename)
     (name, classifier) = pickle.load(open(filename, "rb"))
     assert expected_name == name