# -*- coding: utf-8 -*-
"""
=====================
Parallel Training
=====================

Larger datasets require more time for training.
By default, the models in HiClass are trained using a single core;
however, it is possible to train each local classifier in parallel by leveraging the Ray library [1]_.
In this example, we demonstrate how to train a hierarchical classifier in parallel,
using all available cores, on a mock dataset from Kaggle [2]_.

.. [1] https://www.ray.io/
.. [2] https://www.kaggle.com/datasets/kashnitsky/hierarchical-text-classification
"""
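# Note: parallel training relies on Ray, which is an optional dependency and
# is not imported directly in this script. If Ray is not already available in
# your environment, installing it with pip is assumed to be sufficient for
# this example (the exact package extras may vary):
#
#     pip install ray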
import sys
from os import cpu_count

import pandas as pd
import requests
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

from hiclass import LocalClassifierPerParentNode


def download(url: str, path: str) -> None:
    """
    Download a file from the internet.

    Parameters
    ----------
    url : str
        The address of the file to be downloaded.
    path : str
        The path to store the downloaded file.
    """
    response = requests.get(url)
    with open(path, "wb") as file:
        file.write(response.content)


# Download training data
training_data_url = "https://zenodo.org/record/6657410/files/train_40k.csv?download=1"
training_data_path = "train_40k.csv"
download(training_data_url, training_data_path)

# Load training data into pandas dataframe
training_data = pd.read_csv(training_data_path).fillna(" ")

# We will use logistic regression classifiers for every parent node
lr = LogisticRegression(max_iter=1000)

pipeline = Pipeline(
    [
        ("count", CountVectorizer()),
        ("tfidf", TfidfTransformer()),
        (
            "lcppn",
            LocalClassifierPerParentNode(local_classifier=lr, n_jobs=cpu_count()),
        ),
    ]
)
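# Note: n_jobs controls how many local classifiers are trained in parallel;
# HiClass leverages Ray to distribute this work across workers, and with the
# default setting training runs on a single core, as mentioned above.
# Here, cpu_count() requests one worker per available CPU core.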

# Select training data
X_train = training_data["Title"]
Y_train = training_data[["Cat1", "Cat2", "Cat3"]]

# Workaround for "AttributeError: '_LoggingTee' object has no attribute 'fileno'",
# which only happens when building the documentation,
# so you don't actually need it for your own code to work
sys.stdout.fileno = lambda: False

# Now, let's train the local classifier per parent node
pipeline.fit(X_train, Y_train)
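# As a quick usage sketch (not part of the original example), the fitted
# pipeline can also produce hierarchical predictions through the standard
# scikit-learn predict API; each row contains the predicted labels for
# Cat1, Cat2 and Cat3.
predictions = pipeline.predict(X_train)
print(predictions[:5])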