TreeBoost

Practical tabular ML for messy, real-world data. Fast baselines first, deep control when you need it.

TreeBoost is a Rust-first library for tabular machine learning that starts simple and scales to expert use. It is built for messy real-world datasets: time series, missing values, drift, mixed feature types, and noisy labels. You get a clean path from “just give me a working model” to full control over training, backends, and constraints.

At a Glance

  • AutoML and AutoTuner for fast, explainable baselines
  • Multi-label, multi-class, and regression in one API
  • Hybrid Linear+Tree mode for trend extrapolation + non-linear interactions
  • Built-in preprocessing that serializes with your model
  • GPU acceleration (WebGPU, CUDA) plus AVX-512/SVE2 CPU backends
  • Zero-copy serialization and incremental TRB updates

Why TreeBoost

Most libraries are tuned for leaderboard-style modeling. TreeBoost is built for shipping models:

  • Fast baseline in one call for beginners and teams under time pressure.
  • White-box AutoML that explains why it chose a mode and lets you iterate.
  • Upgradeable control without rewriting your data pipeline.
  • Deployment-friendly: zero-copy serialization, fast inference, and a CLI for batch jobs.

Three API Levels (Start Simple, Go Deep)

  • AutoModel — One call trains a solid baseline and produces a model you can ship. Export a config.json when you want to improve it later.
  • UniversalModel — Choose the learning mode (PureTree, LinearThenTree, RandomForest) and tune it without leaving a high-level API.
  • GBDTModel — Lowest-level API for maximum control, backend selection, and benchmarking.

You can move between these levels without changing your dataset format.
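
For example, you can start at level 1 and continue at level 2 by transcribing fields from the exported config.json into a UniversalConfig. A minimal sketch using only the builder calls from the Quick Starts below; the concrete values are placeholders, not recommendations:

use treeboost::{auto_train, UniversalConfig, UniversalModel, BoostingMode};
use treeboost::loss::MseLoss;

// Level 1: baseline model plus an exported config to iterate on
let auto = auto_train(&df, "target")?;
auto.save_config("config.json")?;

// Level 2: rebuild the chosen settings by hand and refine them
// (`dataset` is the same data loaded via DatasetLoader; see the
// UniversalModel Quick Start below)
let config = UniversalConfig::new()
    .with_mode(BoostingMode::LinearThenTree) // mode picked by AutoML
    .with_num_rounds(200)                    // placeholder: raised from the baseline
    .with_learning_rate(0.05);               // placeholder: lowered to match

let model = UniversalModel::train(&dataset, config, &MseLoss)?;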

📖 See docs/API.md for complete API documentation with examples.

Recommended Workflow

  1. AutoModel for a strong baseline and a training report.
  2. Inspect the report and config to understand the model choice.
  3. Refine with UniversalModel or GBDTModel for extra accuracy, constraints, or incremental updates.

Quick Start (AutoModel)

use treeboost::auto_train;

// `df` is your training DataFrame; "target" names the label column
let model = auto_train(&df, "target")?;
println!("{}", model.summary());

model.save("model.rkyv")?;
model.save_config("config.json")?;

This gives you a deployable model plus a config you can tweak later.

Expected Outputs (After Training)

After training, you typically save:

  • model.rkyv for fast inference and deployment
  • config.json for reproducible retraining or fine-tuning
  • model.trb if you want incremental updates later

For example:

let model = auto_train(&df, "target")?;
model.save("model.rkyv")?;
model.save_config("config.json")?;
model.save_trb("model.trb", "initial training")?;

Example config.json (abridged):

{
  "mode": "LinearThenTree",
  "num_rounds": 120,
  "learning_rate": 0.1,
  "subsample": 0.9,
  "validation_ratio": 0.1,
  "early_stopping_rounds": 20,
  "linear_rounds": 10,
  "tree_config": {
    "max_depth": 6,
    "max_leaves": 31,
    "lambda": 1.0,
    "min_samples_leaf": 20,
    "colsample": 0.8
  },
  "linear_config": {
    "lambda": 1.0,
    "l1_ratio": 0.0,
    "shrinkage_factor": 0.3
  }
}

The actual file includes the full set of tree/linear fields so you can tweak every detail.

Inference (Simple and Fast)

use treeboost::UniversalModel;
use treeboost::dataset::DatasetLoader;

// Use either a static model (.rkyv) or incremental model (.trb)
let model = UniversalModel::load("model.rkyv")?;
// let model = UniversalModel::load_trb("model.trb")?;

let loader = DatasetLoader::new(255);
let dataset = loader.load_parquet("new_data.parquet", "target", None)?;

let predictions = model.predict(&dataset);
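
If you want a quick look at the output, predict returns one score per row here (assumed to be a Vec<f32> in the regression case):

// Print the first few predictions (assumes `predictions` is a Vec<f32>)
for (i, p) in predictions.iter().take(5).enumerate() {
    println!("row {i}: {p:.4}");
}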

Quick Start (UniversalModel)

use treeboost::{UniversalConfig, UniversalModel, BoostingMode};
use treeboost::dataset::DatasetLoader;
use treeboost::loss::MseLoss;

let loader = DatasetLoader::new(255);
let dataset = loader.load_parquet("data.parquet", "target", None)?;

let config = UniversalConfig::new()
    .with_mode(BoostingMode::LinearThenTree)
    .with_num_rounds(100)
    .with_linear_rounds(10)
    .with_learning_rate(0.1);

let model = UniversalModel::train(&dataset, config, &MseLoss)?;
let predictions = model.predict(&dataset);

Quick mode selection:

| Your Data | Use This Mode |
|---|---|
| General tabular, categoricals | BoostingMode::PureTree |
| Time-series, trending, needs extrapolation | BoostingMode::LinearThenTree |
| Noisy data, want robustness | BoostingMode::RandomForest |

Quick Start (Multi-Label Classification)

For problems where each sample can belong to multiple independent labels (e.g., multi-tag classification, multiple disease diagnoses):

use treeboost::AutoModel;

// Train multi-label model (default: LinearThenTree mode)
let target_cols = vec!["tag_finance", "tag_tech", "tag_urgent"];
let model = AutoModel::train_multilabel(&df, &target_cols)?;

// Predictions with probabilities
let probs = model.predict_proba_multilabel(&df)?;  // Vec<Vec<f32>>

// Predictions with labels (threshold 0.5)
let labels = model.predict_labels(&df)?;  // Vec<Vec<bool>>

// Tune thresholds on validation set for better F1 scores
let mut model = AutoModel::train_multilabel(&df, &target_cols)?;
let tune_result = model.tune_thresholds(&val_df, &target_cols)?;
let labels = model.predict_labels_tuned(&df)?;  // Uses tuned thresholds
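
Both outputs are indexed per sample, then per label, in the order of target_cols. As a quick sanity check, here is a plain-Rust sketch (no extra APIs assumed) that computes how often each label fires:

// Positive rate per label from the Vec<Vec<bool>> returned by predict_labels
let mut positives = vec![0usize; target_cols.len()];
for row in &labels {
    for (j, &on) in row.iter().enumerate() {
        if on {
            positives[j] += 1;
        }
    }
}
for (name, count) in target_cols.iter().zip(&positives) {
    println!("{name}: {:.1}%", 100.0 * *count as f64 / labels.len() as f64);
}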

Quick Start (GBDTModel)

use treeboost::{GBDTConfig, GBDTModel};

let config = GBDTConfig::new()
    .with_num_rounds(200)
    .with_max_depth(6)
    .with_learning_rate(0.05);

let model = GBDTModel::train(&features, num_features, &targets, config, None)?;
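
Unlike the higher-level APIs, GBDTModel::train takes raw buffers. A minimal sketch of preparing them, assuming (check docs/API.md for the exact layout) that features is a row-major flattened f32 vector of num_rows x num_features with one target per row:

// Assumed layout: row r, feature f lives at features[r * num_features + f]
let num_features: usize = 20;
let num_rows: usize = 1_000;
let features: Vec<f32> = (0..num_rows * num_features)
    .map(|i| (i % 97) as f32 / 97.0) // synthetic placeholder values
    .collect();
let targets: Vec<f32> = (0..num_rows)
    .map(|r| features[r * num_features] * 2.0) // target tracks feature 0
    .collect();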

Python (GBDTModel)

import numpy as np
from treeboost import GBDTConfig, GBDTModel

X = np.random.randn(10000, 20).astype(np.float32)
y = (X[:, 0] + X[:, 1] * 2 + np.random.randn(10000) * 0.1).astype(np.float32)

config = GBDTConfig()
config.num_rounds = 100
config.max_depth = 6
config.learning_rate = 0.1

model = GBDTModel.train(X, y, config)

What You Get

  • AutoML mode selection that evaluates probes and explains its choice.
  • Hybrid Linear+Tree architecture for trend extrapolation and interactions.
  • Built-in preprocessing: encoders, scalers, and imputers that serialize with the model.
  • Linear Trees for piecewise-linear data with far fewer trees.
  • Conformal prediction for uncertainty intervals.
  • Incremental learning via TRB format with drift detection.

Advanced Features

TreeBoost includes battle-tested capabilities for real-world deployments.

Feature Matrix

| Category | Capability | Use Case | API Entry Point |
|---|---|---|---|
| Classification | Multi-Label Classification | Multi-tag prediction, multi-diagnosis | AutoModel::train_multilabel() |
| Classification | Threshold Tuning | Per-label F1 optimization | model.tune_thresholds() |
| Classification | Multi-Class Classification | Single category prediction | AutoModel::train() with softmax loss |
| Model Updates | Incremental Learning (TRB format) | Daily model updates, streaming data | UniversalModel::update() |
| Model Updates | O(1) Append Updates | Efficient model versioning | save_trb_update() |
| Model Updates | Memory-Mapped I/O | Large model inference | MmapTrbReader (mmap feature) |
| Monitoring | Drift Detection (PSI, KL, KS) | Distribution shift alerts | IncrementalDriftDetector |
| Monitoring | Drift History Tracking | Long-term monitoring | DriftHistory |
| Ensembles | Multi-Seed Training | Variance reduction | with_ensemble_seeds() |
| Ensembles | Stacked Blending | Meta-learner combination | StackingStrategy::Ridge |
| Constraints | Monotonic Constraints | Domain knowledge enforcement | TreeConfig::with_monotonic_constraints() |
| Constraints | Interaction Constraints | Feature interaction control | TreeConfig::with_interaction_constraints() |
| Encoding | Ordered Target Encoding | High-cardinality categoricals | OrderedTargetEncoder |
| Encoding | Count-Min Sketch Filtering | Rare category handling | CategoryFilter |
| Features | Time-Series (Lag/Rolling/EWMA) | Panel data, forecasting | LagGenerator, RollingGenerator |
| Features | Cross-Sectional (Poly/Ratio) | Feature engineering | PolynomialGenerator, RatioGenerator |
| Preprocessing | Incremental Scaler (Welford) | Adaptive preprocessing | StandardScaler::with_forget_factor() |
| Preprocessing | Outlier Detection (IQR/Z-score) | Robust pipelines | OutlierDetector, RobustScaler |
| Uncertainty | Split Conformal Prediction | Distribution-free intervals | GBDTConfig::with_conformal() |

Example: Incremental Learning Workflow

use treeboost::{AutoModel, UniversalModel};
use treeboost::monitoring::IncrementalDriftDetector;
use treeboost::loss::MseLoss;

// 1. Initial training
let auto = AutoModel::train(&df, "target")?;
auto.inner().save_trb("model.trb", "Initial training")?;

// 2. Production: Load and monitor for drift
let mut model = UniversalModel::load_trb("model.trb")?;
// `train_data` is the Dataset the model was originally fit on
let detector = IncrementalDriftDetector::from_dataset(&train_data);

// 3. Before updating, check for drift
let result = detector.check_update(&new_data);
if !result.has_critical_drift() {
    let report = model.update(&new_data, &MseLoss, 10)?;
    model.save_trb_update("model.trb", new_data.num_rows(), "Weekly update")?;
} else {
    eprintln!("Critical drift detected: {}", result);
}
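
Because save_trb_update appends to the existing TRB file rather than rewriting it (the O(1) Append Updates row in the feature matrix), the weekly update above stays cheap even as the model's history grows.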

Why These Features Matter

  • Incremental Learning: Update models in O(new_data) instead of O(total_data) - essential for daily retraining
  • Drift Detection: Catch distribution shifts before they degrade model performance
  • Ensemble Methods: Reduce variance and improve stability in noisy environments
  • Constraints: Enforce domain knowledge (e.g., "age must increase risk") for trust and interpretability, as sketched below
  • High-Cardinality Encoding: Handle millions of categories without memory explosion
  • Time-Series Features: Automatic lag/rolling/EWMA generation for panel data
  • Conformal Prediction: Valid uncertainty estimates regardless of data distribution
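
As a concrete example of the constraints bullet, here is a minimal sketch of the "age must increase risk" case via TreeConfig::with_monotonic_constraints(). The import path, TreeConfig::new(), and the per-feature encoding (+1 increasing, -1 decreasing, 0 unconstrained) are assumptions; check docs/API.md for the exact signature:

use treeboost::TreeConfig; // import path assumed

// Assumed encoding, one entry per feature:
// +1 = prediction may only rise with the feature (e.g. feature 0 = age),
// -1 = may only fall, 0 = unconstrained.
let tree_config = TreeConfig::new()
    .with_monotonic_constraints(vec![1, 0, 0, -1]);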

📖 For detailed API documentation with examples, see docs/API.md

Backends (Automatic by Default)

TreeBoost auto-selects the fastest backend. You can override it if needed.

use treeboost::{GBDTConfig, GBDTModel};
use treeboost::backend::BackendType;

let config = GBDTConfig::new()
    .with_backend(BackendType::Scalar);

let model = GBDTModel::train(&features, num_features, &targets, config, None)?;

Supported backends: Scalar, AVX-512, SVE2, WGPU, CUDA.

CLI

# Train a model
treeboost train --data data.csv --target price --output model.rkyv

# Predict
treeboost predict --model model.rkyv --data test.csv --output predictions.json

# Predict (.trb)
treeboost predict --model model.trb --data test.csv --output predictions.json

Run treeboost --help for full options.

Installation

cargo add treeboost
# Python bindings (requires Rust toolchain + maturin)
pip install treeboost

Feature flags: gpu, cuda, mmap, python.
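
To pull in the optional backends and formats, enable the matching flags when adding the crate, e.g.:

cargo add treeboost --features gpu,mmap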

License

Apache License 2.0
