Skip to content

galilai-group/stable-pretraining

Repository files navigation

stable-pretraining

Documentation Benchmarks Test Status PyTorch Ruff License WandB

AI is moving beyond labels. Today's models learn through self-supervision and multimodal alignment, extracting knowledge from raw data to build general-purpose representations that work across tasks. These foundation models are then deployed at scale, often after finetuning, to solve tasks in zero or few shot.

stable-pretraining is a PyTorch framework built on top of Lightning for this new paradigm. What sets us apart is real-time visibility into training quality through extensive logging and monitoring. Our callback ecosystem (OnlineProbe, OnlineKNN, RankMe, and many more) provides insights into feature collapse, training dynamics, and downstream performance. Data flows as dictionaries through model components, metrics, and callbacks, making any intermediate value accessible and debuggable. With stable-pretraining: track everything, debug faster, iterate sooner.

Join our Discord: https://discord.gg/8M6hT39X

How?

To reach flexibility, scalability and stability, we rely on battle-tested third party libraries: PyTorch, Lightning, HuggingFace, TorchMetrics amongst a few others. Those dependencies allow us to focus on assembling everything into a powerful ML framework. stable-pretraining adopts a flexible and modular design for seamless integration of components from external libraries, including architectures, loss functions, evaluation metrics, and augmentations.

Quick setup

# Clone the repository
git clone https://github.com/galilai-group/stable-pretraining.git

# Install the framework
cd stable-pretraining
pip install -e .

For advanced installation options, see Installation below.

Core Structure

stable-pretraining simplifies complex ML workflows into 4 intuitive components:

1 - Data

Your dataset must follow a dictionary-structured format where each sample is a dictionary with named fields (e.g., {"image": ..., "label": ...}). This ensures consistent behavior across all components. You have multiple options for creating datasets:

  • HuggingFace datasets (if available on the Hub):
import stable_pretraining as spt
train_dataset = spt.data.HFDataset(
    path="frgfm/imagenette",
    name="160px",
    split="train",
    transform=train_transform,
)
  • From PyTorch datasets:
train_dataset = spt.data.FromTorchDataset(
    torchvision_dataset,
    names=["image", "label"],  # Map tuple outputs to dictionary keys
    transform=train_transform,
)
  • Custom datasets: Any dataset that returns dictionaries
datamodule = spt.data.DataModule(train=train_dataloader, val=val_dataloader)

2 - Module

The key differentiator from PyTorch Lightning - you only define the forward function, not training_step! This unified approach computes losses and generates useful quantities that can be retrieved for monitoring and analysis:

# Use the pre-built forward functions from stable_pretraining
from stable_pretraining import forward

# Simply use the appropriate forward for your method
module = spt.Module(
    backbone=backbone,
    projector=projector,
    forward=forward.simclr_forward,  # Or byol_forward, vicreg_forward, etc.
    simclr_loss=spt.losses.NTXEntLoss(temperature=0.5),
    optim={
        "optimizer": {"type": "Adam", "lr": 0.001},
        "scheduler": {"type": "CosineAnnealingLR"},
        "interval": "epoch"
    }
)

Or define your own custom forward:

def forward(self, batch, stage):
    out = {}

    if isinstance(batch, list):
        # Multi-view training - batch is a list of view dicts
        embeddings = [self.backbone(view["image"]) for view in batch]
        out["embedding"] = torch.cat(embeddings, dim=0)

        if self.training:
            projections = [self.projector(emb) for emb in embeddings]
            out["loss"] = self.simclr_loss(projections[0], projections[1])
    else:
        # Single-view validation
        out["embedding"] = self.backbone(batch["image"])

    return out

Key points:

  • The forward method defines both the loss and any quantities to monitor
  • No need to override training_step, validation_step, etc.
  • Return a dictionary with a "loss" key for training
  • All model components are passed as kwargs to spt.Module

3 - Callbacks

Monitor and evaluate your models in real-time during training. Callbacks are key ingredients of stable-pretraining, providing rich insights without interrupting your training flow.

Evaluation & Monitoring

Callback Description
OnlineProbe Trains a lightweight linear probe on frozen representations to track downstream task accuracy in real-time. Maintains its own optimizer and training loop.
OnlineKNN Non-parametric k-nearest neighbors evaluator using a rolling queue of cached embeddings. Zero training cost.
RankMe Tracks effective rank of feature representations via singular values. A rank drop signals dimensional collapse.
LiDAR Linear Discriminant Analysis Rank over surrogate classes of augmented views.
CLIPZeroShot Zero-shot classification for CLIP-style models. Compares image embeddings against text-encoded class names.
ImageRetrieval Image retrieval evaluator following the DINO protocol with query/gallery splits.
LatentViz Online 2D visualization of the latent space. Learns a neighborhood-preserving projection and periodically plots it.
EpochMilestones Early-stops training if a metric fails to reach a threshold by a given epoch.

Training Utilities

Callback Description
TeacherStudentCallback Auto-discovers TeacherStudentWrapper instances and performs EMA teacher updates at configurable frequency.
WeightDecayUpdater Updates weight decay on a per-batch schedule (constant, linear, cosine, or exponential).
EmbeddingCache Hooks into named submodules to cache intermediate embeddings for downstream use.

Checkpointing & Export

Callback Description
SklearnCheckpoint Saves and restores scikit-learn models (probes, classifiers) inside Lightning checkpoints.
WandbCheckpoint Uploads checkpoints to Weights & Biases as artifacts with run-resume support.
StrictCheckpointCallback Controls strict/non-strict checkpoint loading with detailed mismatch reporting.
HuggingFaceCheckpointCallback Exports HuggingFace-compatible checkpoints for any PreTrainedModel submodule (zero-knowledge reload).

System & Logging

Callback Description
LoggingCallback Displays validation metrics in a color-coded formatted table after each epoch.
ModuleSummary Logs detailed parameter statistics (trainable, frozen, per-layer) at the start of training.
TrainerInfo Links trainer to DataModule and logs trainer configuration.
SLURMInfo Extracts and logs SLURM environment information (job ID, partition, resources).
EnvironmentDumpCallback Dumps Python version, CUDA info, installed packages, git state, and env vars to environment.json for exact reproducibility.
LogUnusedParametersOnce Reports parameters that receive no gradient after the first backward pass. Useful for catching wiring bugs.
CleanUpCallback Removes selected training artifacts (SLURM logs, Hydra files, checkpoints, etc.) after successful training. Keeps everything on failure for debugging.
ModuleRegistryCallback Registers the module for global logging access. Enables spt.log() and spt.log_dict() from anywhere.

Intelligent Queue System

Callbacks that need rolling feature stores (OnlineKNN, RankMe, LiDAR, LatentViz) share memory through an automatic queue management system. If two callbacks monitor the same key with different queue lengths, a single queue is allocated at the maximum length and shared, eliminating redundant computation.

Why callbacks matter: Get real-time feedback on representation quality, catch issues like collapse early, and track multiple metrics simultaneously. For detailed usage and practical considerations, see the Callback guide.

Example:

# Monitor SSL representations with a linear classifier
linear_probe = spt.callbacks.OnlineProbe(
    module,
    name="linear_probe",
    input="embedding",
    target="label",
    probe=torch.nn.Linear(512, 10),
    loss_fn=torch.nn.CrossEntropyLoss(),
    metrics={
        "top1": torchmetrics.classification.MulticlassAccuracy(10),
        "top5": torchmetrics.classification.MulticlassAccuracy(10, top_k=5),
    },
)

# Track representation quality with KNN evaluation
knn_probe = spt.callbacks.OnlineKNN(
    name="knn_probe",
    input="embedding",
    target="label",
    queue_length=20000,
    k=10,
)

4 - Trainer

Orchestrate everything together with PyTorch Lightning's Trainer:

trainer = pl.Trainer(
    max_epochs=10,
    num_sanity_val_steps=1,
    callbacks=[linear_probe, knn_probe, rankme],  # Your monitoring callbacks
    precision="16-mixed",
    logger=False,
    enable_checkpointing=False,
)
manager = spt.Manager(trainer=trainer, module=module, data=data)
manager()

Once configured, the Manager connects all components and handles the training loop with precise logging and monitoring (optional).

Global Configuration

Instead of scattering options across environment variables, callback constructors, and factory functions, stable-pretraining provides a single entry-point to configure library-wide behavior:

import stable_pretraining as spt

spt.set(
    verbose="WARNING",                          # Global log level (also controls callback verbosity)
    progress_bar="rich",                        # "auto", "rich", "simple", or "none"
    cleanup={"checkpoints": False, "slurm": False},  # What CleanUpCallback removes
    log_rank="all",                             # Which distributed rank(s) may log (default: 0)
    default_callbacks={"env_dump": False},       # Toggle individual default callbacks on/off
)

# Inspect the current configuration
print(spt.get_config())
Setting Type Default Description
verbose str or int "INFO" Loguru log level. Accepts "DEBUG", "INFO", "WARNING", etc., or Python logging ints (10, 20, 30). Also controls the verbose flag on all callbacks when left at their default.
progress_bar str "auto" Progress bar style. "auto" picks "rich" for TTYs and "simple" for non-interactive environments. "none" disables it.
cleanup dict keeps checkpoints & logs Controls which artifact categories CleanUpCallback removes after training. Keys: "checkpoints", "logs", "hydra", "slurm", "env_dump", "callback_artifacts". Values are bools (True = keep, False = delete).
log_rank int or "all" 0 Which distributed rank(s) may produce log output.
default_callbacks dict all enabled Toggle individual default callbacks: "progress_bar", "registry", "logging", "env_dump", "trainer_info", "sklearn_checkpoint", "wandb_checkpoint", "module_summary", "slurm_info", "unused_params", "hf_checkpoint".
default_loggers dict all enabled Toggle default loggers: "registry" (SQLite run registry + per-step CSV logger, added as a pair).
cache_dir str or None None (or SPT_CACHE_DIR env var) Root directory for all training outputs. See Output Directory below.
requeue_checkpoint bool True Auto-add a last.ckpt checkpoint every epoch for SLURM requeue. Set to False to save time/disk when preemption is not a concern. Only applies when cache_dir is set.

Settings apply immediately and persist for the process lifetime. spt.set() can be called multiple times; only the settings you pass are updated.

Output Directory (cache_dir)

By default, Lightning and Hydra scatter training outputs (checkpoints, logs, wandb data) based on the current working directory or Hydra's run.dir. This causes collisions when multiple sweep jobs start at the same time and resolve to the same path.

stable-pretraining solves this with a centralized cache_dir. When set, every run gets its own unique directory and all outputs are routed there automatically:

{cache_dir}/runs/{YYYYMMDD}/{HHMMSS}/{run_id}/
├── checkpoints/last.ckpt
├── wandb_resume.json
├── run_meta.json
├── environment.json
└── ...

Enabling cache_dir

import stable_pretraining as spt

# Option 1: in Python
spt.set(cache_dir="~/.cache/stable_pretraining")

# Option 2: via environment variable (e.g. in ~/.bashrc)
# export SPT_CACHE_DIR=~/.cache/stable_pretraining

When cache_dir is set, the Manager:

  1. Creates a unique run directory under cache_dir/runs/.
  2. Sets the Trainer's default_root_dir to that directory (before instantiation).
  3. Redirects all ModelCheckpoint callbacks to run_dir/checkpoints/ (preserving their filename, monitor, and other settings).
  4. Adds a requeue checkpoint (last.ckpt, saved every epoch) for seamless SLURM preemption recovery. You never need to add one yourself.
  5. Routes all callback outputs (environment dumps, latent visualizations, HuggingFace exports, etc.) there.

If preemption is not a concern and you want to skip the requeue checkpoint overhead:

spt.set(cache_dir="/scratch/runs", requeue_checkpoint=False)

When cache_dir is not set (None, the default), the library behaves exactly as before.

How the run ID is generated

The run ID is chosen to be deterministic across all ranks of the same job, so multi-GPU training always agrees on a single directory:

Environment Run ID Example
SLURM SLURM_JOB_ID 99999
SLURM array job SLURM_JOB_ID_SLURM_ARRAY_TASK_ID 99999_3
torchrun TORCHELASTIC_RUN_ID abc123
Local / other Random UUID (12 hex chars) a1b2c3d4e5f6

ckpt_path is for loading only

ckpt_path and cache_dir serve different purposes:

  • ckpt_path = where to load weights from (one-time, read-only).
  • cache_dir = where to save everything going forward.
manager = spt.Manager(
    trainer=trainer_cfg,
    module=module,
    data=data,
    ckpt_path="/old/run/pretrained.ckpt",  # Load from here once
)
# New checkpoints, logs, wandb data → cache_dir/runs/.../

If you don't pass ckpt_path, the system checks run_dir/checkpoints/last.ckpt automatically. This means SLURM requeue works without any user configuration: the job is preempted, restarted with the same SLURM_JOB_ID, finds its previous run directory, and resumes from the last checkpoint.

Hydra compatibility

When cache_dir is active, Hydra's run.dir, sweep.dir, and job.chdir settings are ignored for trainer outputs (a warning is logged). Hydra still manages its own .hydra/ config dumps as usual. Note that SLURM .out/.err files are created by the scheduler before Python starts and cannot be redirected into the run directory.

Run Registry

When cache_dir is set, stable-pretraining automatically maintains a local SQLite registry that indexes every run. Think of it as a local, offline, instant-query alternative to the wandb dashboard — designed for large sweeps on HPC clusters.

There is nothing to configure — if cache_dir is set, the registry is active:

import stable_pretraining as spt

spt.set(cache_dir="/scratch/runs")
# That's it. Every run is indexed in /scratch/runs/registry.db

What gets stored

The registry captures three categories of data per run, all automatically:

  • Config / hparams — the full Hydra config (trainer, module, data) is flattened into dot-separated keys (e.g. module.optim.optimizer.lr, trainer.max_epochs) and stored as both config and hparams. This works the same way as wandb's config: the Manager flattens the Hydra DictConfig and injects it into module.save_hyperparameters() before trainer.fit(), so Lightning's built-in _log_hyperparams sends it to all loggers (wandb, CSV, TensorBoard, and the registry) automatically.
  • Summary — every self.log() call in your LightningModule accumulates into a wandb-style summary dict (last value per metric key). At the end of training, the final summary (e.g. {"val_acc": 0.85, "train_loss": 0.12}) is written to the database.
  • Metadata — run ID, status (running/completed/failed), tags, notes, run_dir path, and best checkpoint path.

Grouping with tags

All grouping is done through tags — a flat list of strings attached to each run. There is no separate "project" or "group" concept; cache_dir already acts as the project (one database file), and tags handle everything else.

For SLURM array jobs, a "sweep:<SLURM_ARRAY_JOB_ID>" tag is automatically added so that all tasks in the same array are queryable as a group. You can add your own tags in YAML:

logger:
  - _target_: stable_pretraining.registry.RegistryLogger
    tags: [resnet50, simclr, ablation-v2]
    notes: "Testing higher learning rates"

Querying runs

Use spt.open_registry() from a notebook or script to query across all your runs:

import stable_pretraining as spt

spt.set(cache_dir="/scratch/runs")
reg = spt.open_registry()

# All completed runs from a SLURM array sweep
best = reg.query(tag="sweep:12345", status="completed", sort_by="summary.val_acc", limit=5)
for r in best:
    print(f"{r.run_id}: val_acc={r.summary['val_acc']:.3f}  lr={r.hparams['module.optim.optimizer.lr']}")

# Load the best checkpoint directly
import torch
ckpt = torch.load(best[0].checkpoint_path)

# Filter by any Hydra config key (deeply nested keys work)
lars_runs = reg.query(hparams={"module.optim.optimizer.type": "LARS"})

# All resnet runs as a pandas DataFrame
df = reg.to_dataframe(tag="resnet50")
# Columns include flattened hparams and summary:
#   run_id, status, tags, notes, checkpoint_path,
#   hparams.module.optim.optimizer.lr, hparams.trainer.max_epochs,
#   summary.val_acc, summary.train_loss, ...

# Quick analysis
df[["run_id", "hparams.module.optim.optimizer.lr", "summary.val_acc"]].sort_values(
    "summary.val_acc", ascending=False
).head(10)

Concurrency and SLURM requeue

The registry uses SQLite in WAL mode with exponential-backoff retries, so thousands of concurrent SLURM jobs can safely write to the same database file. On SLURM requeue, the run ID is deterministic (derived from SLURM_JOB_ID), so the requeued job reconnects to the same registry row and resumes seamlessly — the same mechanism used for wandb run resumption and checkpoint recovery.

CLI

The spt registry command lets you query runs from the terminal:

# List all runs
spt registry ls

# Filter by tag or status
spt registry ls --tag resnet50 --status completed

# Top 5 runs by a metric (use --asc for losses)
spt registry best val_acc
spt registry best train_loss --asc -n 10

# Show full details for a run (config, summary, tags)
spt registry show <run_id>

# Export to CSV or Parquet
spt registry export sweep_results.csv --tag sweep:12345

By default the CLI uses ~/.cache/stable-pretraining/registry.db (or $SPT_CACHE_DIR/registry.db if set). Pass --db /path/to/registry.db to query a different database.

Disabling the registry

spt.set(default_loggers={"registry": False})

Built-in Methods

stable-pretraining ships with ready-to-use forward functions and matching loss functions for popular self-supervised learning methods:

Method Forward Loss Description
Supervised forward.supervised_forward any Standard supervised training with labels
SimCLR forward.simclr_forward NTXEntLoss Contrastive learning with 2 augmented views
BYOL forward.byol_forward BYOLLoss Momentum-based self-distillation without negatives
VICReg forward.vicreg_forward VICRegLoss Variance-invariance-covariance regularization
Barlow Twins forward.barlow_twins_forward BarlowTwinsLoss Cross-correlation matrix alignment to identity
SwAV forward.swav_forward sinkhorn_knopp Online clustering with Sinkhorn-Knopp normalization
NNCLR forward.nnclr_forward NTXEntLoss Nearest-neighbor contrastive learning
DINO forward.dino_forward DINOv1Loss Self-distillation with multi-crop and centering
DINOv2 forward.dinov2_forward DINOv2Loss DINO + iBOT masked patch prediction

Backbones

Load architectures from popular libraries or use built-in components:

# From torchvision
backbone = spt.backbone.from_torchvision("resnet50")

# From timm (thousands of pretrained models)
backbone = spt.backbone.from_timm("vit_base_patch16_224")

# From HuggingFace
backbone = spt.backbone.vit_hf("google/vit-base-patch16-224")

Additional building blocks: MLP, ConvMixer, Resnet9, TeacherStudentWrapper (EMA), MAEDecoder, MaskedEncoder, FlexibleTransformer, PatchMasking, IJEPAMasking, MultiBlockMasking, LinearProbe, AutoLinearClassifier, AutoTuneMLP, and more.

Optimizers & Schedulers

Component Description
LARS Layer-wise Adaptive Rate Scaling - the standard optimizer for SSL
LinearWarmupCosineAnnealing Linear warmup followed by cosine decay
LinearWarmupCyclicAnnealing Linear warmup followed by cyclic cosine decay
CosineDecayer Pure cosine decay schedule
create_optimizer / create_scheduler Factory functions that accept string names, dicts, or partial objects
module = spt.Module(
    ...,
    optim={
        "optimizer": {"type": "LARS", "lr": 5, "weight_decay": 1e-6},
        "scheduler": {"type": "LinearWarmupCosineAnnealing"},
        "interval": "epoch",
    },
)

Complete Example

SimCLR on CIFAR-10

This example demonstrates the key features of stable-pretraining: dictionary-structured data, unified forward function, and rich monitoring through callbacks.

import lightning as pl
import torch
import torchmetrics
import torchvision
from torch import nn
from lightning.pytorch.loggers import WandbLogger

import stable_pretraining as spt
from stable_pretraining import forward
from stable_pretraining.data import transforms

# Define augmentations for SimCLR (creates 2 views of each image)
simclr_transform = transforms.MultiViewTransform(
    [
        transforms.Compose(
            transforms.RGB(),
            transforms.RandomResizedCrop((32, 32), scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToImage(**spt.data.static.CIFAR10),
        ),
        # Second view with slightly different augmentations
        transforms.Compose(
            transforms.RGB(),
            transforms.RandomResizedCrop((32, 32), scale=(0.08, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomSolarize(threshold=0.5, p=0.2),
            transforms.ToImage(**spt.data.static.CIFAR10),
        ),
    ]
)

# Load CIFAR-10 and wrap in dictionary format
cifar_train = torchvision.datasets.CIFAR10(root="./data", train=True, download=True)
cifar_val = torchvision.datasets.CIFAR10(root="./data", train=False, download=True)

train_dataset = spt.data.FromTorchDataset(
    cifar_train,
    names=["image", "label"],  # Convert tuple to dictionary
    transform=simclr_transform,
)

val_dataset = spt.data.FromTorchDataset(
    cifar_val,
    names=["image", "label"],
    transform=transforms.Compose(
        transforms.RGB(),
        transforms.Resize((32, 32)),
        transforms.ToImage(**spt.data.static.CIFAR10),
    ),
)

# Create dataloaders - MultiViewTransform handles the view creation
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=256,
    num_workers=8,
    drop_last=True,
    shuffle=True,  # Simple shuffle, no RepeatedRandomSampler needed
)

val_dataloader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=256,
    num_workers=10,
)

data = spt.data.DataModule(train=train_dataloader, val=val_dataloader)

# Build model components
backbone = spt.backbone.from_torchvision("resnet18", low_resolution=True)
backbone.fc = torch.nn.Identity()  # Remove classification head

projector = nn.Sequential(
    nn.Linear(512, 2048),
    nn.BatchNorm1d(2048),
    nn.ReLU(inplace=True),
    nn.Linear(2048, 2048),
    nn.BatchNorm1d(2048),
    nn.ReLU(inplace=True),
    nn.Linear(2048, 256),
)

# Create the module using the built-in SimCLR forward function
module = spt.Module(
    backbone=backbone,
    projector=projector,
    forward=forward.simclr_forward,  # Use the built-in forward function
    simclr_loss=spt.losses.NTXEntLoss(temperature=0.5),
    optim={
        "optimizer": {"type": "LARS", "lr": 5, "weight_decay": 1e-6},
        "scheduler": {"type": "LinearWarmupCosineAnnealing"},
        "interval": "epoch",
    },
)

# Add callbacks for monitoring performance during training
linear_probe = spt.callbacks.OnlineProbe(
    module,
    name="linear_probe",
    input="embedding",
    target="label",
    probe=torch.nn.Linear(512, 10),
    loss_fn=torch.nn.CrossEntropyLoss(),
    metrics={
        "top1": torchmetrics.classification.MulticlassAccuracy(10),
        "top5": torchmetrics.classification.MulticlassAccuracy(10, top_k=5),
    },
)

knn_probe = spt.callbacks.OnlineKNN(
    name="knn_probe",
    input="embedding",
    target="label",
    queue_length=20000,
    metrics={"accuracy": torchmetrics.classification.MulticlassAccuracy(10)},
    input_dim=512,
    k=10,
)

# Configure training
trainer = pl.Trainer(
    max_epochs=1000,
    callbacks=[knn_probe, linear_probe],  # Monitor SSL quality in real-time
    precision="16-mixed",
    logger=WandbLogger(project="cifar10-simclr"),
)

# Launch training
manager = spt.Manager(trainer=trainer, module=module, data=data)
manager()

Quick Start with spt CLI

The spt command launches training from YAML configuration files using Hydra.

Note: spt requires YAML configs. If you have Python-based configs, you can:

  • Convert them to YAML format where each component uses _target_ to specify the importable class/function
  • See examples/simclr_cifar10_config.yaml for the structure and syntax

Local Training

# Run with a config file
spt examples/simclr_cifar10_config.yaml

# With parameter overrides
spt examples/simclr_cifar10_config.yaml trainer.max_epochs=50 module.optim.lr=0.01

# Run from any directory - supports absolute and relative paths
spt ../configs/my_config.yaml
spt /path/to/config.yaml

SLURM Cluster Training

For training on SLURM clusters, use the -m flag to enable multirun mode:

# Use the provided SLURM template (customize partition/QOS in the file)
spt examples/simclr_cifar10_slurm.yaml -m

# Override SLURM parameters via command line
spt examples/simclr_cifar10_slurm.yaml -m \
    hydra.launcher.partition=gpu \
    hydra.launcher.qos=normal \
    hydra.launcher.timeout_min=720

The SLURM template (examples/simclr_cifar10_slurm.yaml) includes placeholders for cluster-specific settings. Either modify the file directly or override values via command line.

Installation

The library is not yet available on PyPI. You can install it from the source code, as follows.

  1. conda (optional)

    First use your favorite environment manager and install your favorite pytorch version, we provide an example with conda

    wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
    bash Miniconda3-latest-Linux-x86_64.sh
    

    follow installation instructions... once completed, create your environment

    conda create -n my_env python=3.11
    

    with your environment name (here my_env) and your favorite Python version (here, 3.11). Once completed, make sure to activate your environment (conda activate my_env) before proceeding to the next steps!

  1. Pytorch and our library (we recommend using uv for quicker package management):

    pip3 install uv
    uv pip install torch torchvision torchaudio
    uv pip install -e .  # Core dependencies only

    For optional features (vision models, experiment tracking, cluster support, etc.):

    uv pip install -e ".[vision,tracking]"  # Example: add vision models and wandb
    uv pip install -e ".[all]"  # Or install all optional dependencies

    See pyproject.toml for available dependency groups (vision, tracking, cluster, visualization, datasets, extras, dev, doc).

    If you do not want to use uv, simply remove it from the above commands.

  2. API login (optional)

    wandb login
    huggingface-cli login
    
  3. LaTeX support in Matplotlib (optional)

    Click to expand setup instructions

    Install Computer Modern fonts:

    mkdir -p ~/.local/share/fonts
    cp assets/cm-unicode-0.7.0\ 2/*.ttf ~/.local/share/fonts/
    fc-cache -f -v
    # verify: fc-list | grep cmu

    Clear matplotlib font cache:

    python -c "import shutil, matplotlib; shutil.rmtree(matplotlib.get_cachedir())"

    Install TeX Live (minimal, no sudo):

    cd /tmp
    wget https://mirror.ctan.org/systems/texlive/tlnet/install-tl-unx.tar.gz
    tar xzf install-tl-unx.tar.gz
    cd install-tl-*/
    ./install-tl --texdir ~/texlive --no-interaction --scheme=scheme-basic
    ~/texlive/bin/x86_64-linux/tlmgr install type1cm cm-super dvipng collection-fontsrecommended amsmath amssymb bm underscore xcolor

    Add to your ~/.bashrc (or equivalent):

    export PATH="$HOME/texlive/bin/x86_64-linux:$PATH"

    Verify:

    python -c "
    import matplotlib; matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')
    plt.figure(); plt.title(r'\$\sum_{i=1}^n x_i\$')
    plt.savefig('/tmp/tex_test.png')
    print('Success!')
    "

Ways You Can Contribute:

  • If you'd like to contribute new features, bug fixes, or improvements to the documentation, please refer to our contributing guide for detailed instructions on how to get started.

  • You can also contribute by adding new methods, datasets, or configurations that improve the current performance of a method in the benchmark section.

Citation

If you use stable-pretraining in your research, please cite:

@article{balestriero2025stable,
  title={stable-pretraining-v1: Foundation Model Research Made Simple},
  author={Balestriero, Randall and Van Assel, Hugues and BuGhanem, Sami and Maes, Lucas},
  journal={arXiv preprint arXiv:2511.19484},
  year={2025}
}

Contributors

Core contributors (in order of joining the project):