diff --git a/hiclass/HierarchicalClassifier.py b/hiclass/HierarchicalClassifier.py
index 23e422ab..e36e1e14 100644
--- a/hiclass/HierarchicalClassifier.py
+++ b/hiclass/HierarchicalClassifier.py
@@ -71,6 +71,7 @@ def __init__(
         bert: bool = False,
         classifier_abbreviation: str = "",
         tmp_dir: str = None,
+        warm_start: bool = False,
     ):
         """
         Initialize a local hierarchical classifier.
@@ -99,6 +100,9 @@ def __init__(
         tmp_dir : str, default=None
             Temporary directory to persist local classifiers that are trained. If the job needs to be restarted,
             it will skip the pre-trained local classifier found in the temporary directory.
+        warm_start : bool, default=False
+            When set to true, the hierarchical classifier reuses the solution of the previous call to fit, that is,
+            new classes can be added.
         """
         self.local_classifier = local_classifier
         self.verbose = verbose
@@ -108,6 +112,7 @@ def __init__(
         self.bert = bert
         self.classifier_abbreviation = classifier_abbreviation
         self.tmp_dir = tmp_dir
+        self.warm_start = warm_start
 
     def fit(self, X, y, sample_weight=None):
         """
@@ -155,6 +160,8 @@ def _pre_fit(self, X, y, sample_weight):
         else:
             self.sample_weight_ = None
 
+        self.warm_start_ = self.warm_start
+
         self.y_ = make_leveled(self.y_)
 
         # Create and configure logger
@@ -164,7 +171,7 @@ def _pre_fit(self, X, y, sample_weight):
         # which would generate the prediction a->b->c
         self._disambiguate()
 
-        # Create DAG from self.y_ and store to self.hierarchy_
+        # Create or update DAG from self.y_ and store to self.hierarchy_
         self._create_digraph()
 
         # If user passes edge_list, then export
@@ -229,7 +236,7 @@ def _create_digraph(self):
         self._create_digraph_2d()
 
         if self.y_.ndim > 2:
-            # Unsuported dimension
+            # Unsupported dimension
             self.logger_.error(f"y with {self.y_.ndim} dimensions detected")
             raise ValueError(
                 f"Creating graph from y with {self.y_.ndim} dimensions is not supported"
@@ -250,7 +257,11 @@ def _create_digraph_1d(self):
     def _create_digraph_2d(self):
         if self.y_.ndim == 2:
             # Create max_levels variable
-            self.max_levels_ = self.y_.shape[1]
+            # (a warm start only widens it, and needs a previous fit to compare with)
+            if self.warm_start_ and hasattr(self, "max_levels_"):
+                self.max_levels_ = max(self.max_levels_, self.y_.shape[1])
+            else:
+                self.max_levels_ = self.y_.shape[1]
             rows, columns = self.y_.shape
             self.logger_.info(f"Creating digraph from {rows} 2D labels")
             for row in range(rows):
@@ -296,7 +307,11 @@ def _add_artificial_root(self):
         self.logger_.info(f"Detected {len(roots)} roots")
 
         # Add artificial root as predecessor to root(s) detected
-        self.root_ = "hiclass::root"
+        if self.warm_start_ and hasattr(self, "root_"):
+            # Keep the existing artificial root and never link it to itself
+            roots = [root for root in roots if root != self.root_]
+        else:
+            self.root_ = "hiclass::root"
         for old_root in roots:
             self.hierarchy_.add_edge(self.root_, old_root)
 
diff --git a/hiclass/LocalClassifierPerLevel.py b/hiclass/LocalClassifierPerLevel.py
index 907e61cf..2e41d331 100644
--- a/hiclass/LocalClassifierPerLevel.py
+++ b/hiclass/LocalClassifierPerLevel.py
@@ -53,6 +53,7 @@ def __init__(
         n_jobs: int = 1,
         bert: bool = False,
         tmp_dir: str = None,
+        warm_start: bool = False,
     ):
         """
         Initialize a local classifier per level.
@@ -79,6 +80,9 @@ def __init__(
         tmp_dir : str, default=None
             Temporary directory to persist local classifiers that are trained. If the job needs to be restarted,
             it will skip the pre-trained local classifier found in the temporary directory.
+        warm_start : bool, default=False
+            When set to true, the hierarchical classifier reuses the solution of the previous call to fit, that is,
+            new classes can be added.
         """
         super().__init__(
             local_classifier=local_classifier,
@@ -89,6 +93,7 @@ def __init__(
             classifier_abbreviation="LCPL",
             bert=bert,
             tmp_dir=tmp_dir,
+            warm_start=warm_start,
         )
 
     def fit(self, X, y, sample_weight=None):
@@ -115,6 +120,8 @@ def fit(self, X, y, sample_weight=None):
         # Execute common methods necessary before fitting
         super()._pre_fit(X, y, sample_weight)
 
+        # TODO: add partial_fit here if warm_start=True
+
         # Fit local classifiers in DAG
         super().fit(X, y)
 
diff --git a/hiclass/LocalClassifierPerNode.py b/hiclass/LocalClassifierPerNode.py
index 1382c72e..4b2853c1 100644
--- a/hiclass/LocalClassifierPerNode.py
+++ b/hiclass/LocalClassifierPerNode.py
@@ -48,6 +48,7 @@ def __init__(
         n_jobs: int = 1,
         bert: bool = False,
         tmp_dir: str = None,
+        warm_start: bool = False,
     ):
         """
         Initialize a local classifier per node.
@@ -85,6 +86,9 @@ def __init__(
         tmp_dir : str, default=None
             Temporary directory to persist local classifiers that are trained. If the job needs to be restarted,
             it will skip the pre-trained local classifier found in the temporary directory.
+        warm_start : bool, default=False
+            When set to true, the hierarchical classifier reuses the solution of the previous call to fit, that is,
+            new classes can be added.
         """
         super().__init__(
             local_classifier=local_classifier,
@@ -95,6 +99,7 @@ def __init__(
             classifier_abbreviation="LCPN",
             bert=bert,
             tmp_dir=tmp_dir,
+            warm_start=warm_start,
         )
         self.binary_policy = binary_policy
 
@@ -125,6 +130,8 @@ def fit(self, X, y, sample_weight=None):
         # Initialize policy
         self._initialize_binary_policy()
 
+        # TODO: add partial_fit here if warm_start=True
+
         # Fit local classifiers in DAG
         super().fit(X, y)
 
diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py
index 47f77475..8c507f9c 100644
--- a/hiclass/LocalClassifierPerParentNode.py
+++ b/hiclass/LocalClassifierPerParentNode.py
@@ -46,6 +46,7 @@ def __init__(
         n_jobs: int = 1,
         bert: bool = False,
         tmp_dir: str = None,
+        warm_start: bool = False,
     ):
         """
         Initialize a local classifier per parent node.
@@ -72,6 +73,9 @@ def __init__(
         tmp_dir : str, default=None
             Temporary directory to persist local classifiers that are trained. If the job needs to be restarted,
             it will skip the pre-trained local classifier found in the temporary directory.
+        warm_start : bool, default=False
+            When set to true, the hierarchical classifier reuses the solution of the previous call to fit, that is,
+            new classes can be added.
         """
         super().__init__(
             local_classifier=local_classifier,
@@ -82,6 +86,7 @@ def __init__(
             classifier_abbreviation="LCPPN",
             bert=bert,
             tmp_dir=tmp_dir,
+            warm_start=warm_start,
         )
 
     def fit(self, X, y, sample_weight=None):
@@ -165,6 +170,44 @@ def predict(self, X):
 
         return y
 
+    def partial_fit(self, X, y, sample_weight=None):
+        """
+        Add new parent nodes for the local classifier per parent node.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            The training input samples. Internally, its dtype will be converted
+            to ``dtype=np.float32``. If a sparse matrix is provided, it will be
+            converted into a sparse ``csc_matrix``.
+        y : array-like of shape (n_samples, n_levels)
+            The target values, i.e., hierarchical class labels for classification.
+        sample_weight : array-like of shape (n_samples,), default=None
+            Array of weights that are assigned to individual samples.
+            If not provided, then each sample is given unit weight.
+
+        Returns
+        -------
+        self : object
+            Fitted estimator.
+        """
+        # _pre_fit() derives warm_start_ from the warm_start hyperparameter, so
+        # force a warm start through that flag for the duration of this call
+        # and restore the user's setting afterwards.
+        warm_start = self.warm_start
+        self.warm_start = True
+        try:
+            # Execute common methods necessary before fitting
+            super()._pre_fit(X, y, sample_weight)
+
+            # Fit local classifiers in DAG
+            super().fit(X, y)
+        finally:
+            self.warm_start = warm_start
+
+        # Return the classifier
+        return self
+
     def _predict_remaining_levels(self, X, y):
         for level in range(1, y.shape[1]):
             predecessors = set(y[:, level - 1])
@@ -183,7 +226,8 @@ def _initialize_local_classifiers(self):
         local_classifiers = {}
         nodes = self._get_parents()
         for node in nodes:
-            local_classifiers[node] = {"classifier": deepcopy(self.local_classifier_)}
+            if "classifier" not in self.hierarchy_.nodes[node]:
+                local_classifiers[node] = {"classifier": deepcopy(self.local_classifier_)}
         nx.set_node_attributes(self.hierarchy_, local_classifiers)
 
     def _get_parents(self):
diff --git a/tests/test_HierarchicalClassifier.py b/tests/test_HierarchicalClassifier.py
index 3333cf52..4cf72c0b 100644
--- a/tests/test_HierarchicalClassifier.py
+++ b/tests/test_HierarchicalClassifier.py
@@ -57,6 +57,21 @@ def test_create_digraph_1d(graph_1d):
     assert list(ground_truth.edges) == list(graph_1d.hierarchy_.edges)
 
 
+def test_update_digraph_1d(graph_1d):
+    ground_truth = nx.DiGraph()
+    ground_truth.add_nodes_from(np.array(["a", "b", "c", "d", "e", "f"]))
+    graph_1d._create_digraph()
+    attributes = {}
+    attributes["a"] = {"trained_classifier": "yes"}
+    nx.set_node_attributes(graph_1d.hierarchy_, attributes)
+    graph_1d.y_ = np.array(["a", "b", "c", "d", "e", "f"])
+    graph_1d._create_digraph_1d()
+    assert nx.is_isomorphic(ground_truth, graph_1d.hierarchy_)
+    assert list(ground_truth.nodes) == list(graph_1d.hierarchy_.nodes)
+    assert list(ground_truth.edges) == list(graph_1d.hierarchy_.edges)
+    assert graph_1d.hierarchy_.nodes["a"]["trained_classifier"] == "yes"
+
+
 @pytest.fixture
 def graph_1d_disguised_as_2d():
     classifier = HierarchicalClassifier()
@@ -82,6 +97,8 @@ def digraph_2d():
     classifier.logger_ = logging.getLogger("HC")
     classifier.edge_list = tempfile.TemporaryFile()
     classifier.separator_ = "::HiClass::Separator::"
+    classifier.warm_start_ = True
+    classifier.max_levels_ = 3
     return classifier
 
 
@@ -93,6 +110,31 @@ def test_create_digraph_2d(digraph_2d):
     assert list(ground_truth.edges) == list(digraph_2d.hierarchy_.edges)
 
 
+def test_update_digraph_2d(digraph_2d):
+    ground_truth = nx.DiGraph(
+        [
+            ("a", "b"),
+            ("b", "c"),
+            ("d", "e"),
+            ("e", "f"),
+            ("g", "h"),
+            ("h", "i"),
+            ("i", "j"),
+        ]
+    )
+    digraph_2d._create_digraph()
+    attributes = {}
+    attributes["b"] = {"trained_classifier": "yes"}
+    nx.set_node_attributes(digraph_2d.hierarchy_, attributes)
+    digraph_2d.y_ = np.array([["g", "h", "i", "j"]])
+    digraph_2d._create_digraph_2d()
+    assert nx.is_isomorphic(ground_truth, digraph_2d.hierarchy_)
+    assert list(ground_truth.nodes) == list(digraph_2d.hierarchy_.nodes)
+    assert list(ground_truth.edges) == list(digraph_2d.hierarchy_.edges)
+    assert digraph_2d.hierarchy_.nodes["b"]["trained_classifier"] == "yes"
+    assert digraph_2d.max_levels_ == 4
+
+
 @pytest.fixture
 def digraph_3d():
     classifier = HierarchicalClassifier()
@@ -137,6 +179,7 @@ def digraph_one_root():
     classifier = HierarchicalClassifier()
     classifier.logger_ = logging.getLogger("HC")
     classifier.hierarchy_ = nx.DiGraph([("a", "b"), ("b", "c"), ("c", "d")])
+    classifier.warm_start_ = False
     return classifier
 
 
@@ -155,6 +198,7 @@ def digraph_multiple_roots():
     classifier.X_ = np.array([[1, 2], [3, 4], [5, 6]])
     classifier.y_ = np.array([["a", "b"], ["c", "d"], ["e", "f"]])
     classifier.sample_weight_ = None
+    classifier.warm_start_ = False
     return classifier
 
 
@@ -165,6 +209,17 @@ def test_add_artificial_root_multiple_roots(digraph_multiple_roots):
     assert "hiclass::root" == digraph_multiple_roots.root_
 
 
+def test_add_artificial_new_nodes(digraph_multiple_roots):
+    digraph_multiple_roots._add_artificial_root()
+    digraph_multiple_roots.hierarchy_.add_node("g")
+    digraph_multiple_roots.hierarchy_.add_node("h")
+    digraph_multiple_roots.warm_start_ = True
+    digraph_multiple_roots._add_artificial_root()
+    successors = list(digraph_multiple_roots.hierarchy_.successors("hiclass::root"))
+    assert ["a", "c", "e", "g", "h"] == successors
+    assert "hiclass::root" == digraph_multiple_roots.root_
+
+
 def test_initialize_local_classifiers_2(digraph_multiple_roots):
     digraph_multiple_roots.local_classifier = None
     digraph_multiple_roots._initialize_local_classifiers()
@@ -224,6 +279,7 @@ def test_fit_digraph():
 def test_pre_fit_bert():
     classifier = HierarchicalClassifier()
     classifier.logger_ = logging.getLogger("HC")
+    classifier.warm_start_ = False
     classifier.bert = True
     x = [[0, 1], [2, 3]]
     y = [["a", "b"], ["c", "d"]]