Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ linear_probe = spt.callbacks.OnlineProbe(
input="embedding", # Which output from forward to monitor
target="label", # Ground truth from batch
probe=torch.nn.Linear(512, 10),
loss_fn=torch.nn.CrossEntropyLoss(),
loss=torch.nn.CrossEntropyLoss(),
metrics={
"top1": torchmetrics.classification.MulticlassAccuracy(10),
"top5": torchmetrics.classification.MulticlassAccuracy(10, top_k=5),
Expand All @@ -143,6 +143,7 @@ knn_probe = spt.callbacks.OnlineKNN(
input="embedding",
target="label",
queue_length=20000,
metrics={"accuracy": torchmetrics.classification.MulticlassAccuracy(10)},
k=10,
)
```
Expand Down Expand Up @@ -282,7 +283,7 @@ linear_probe = spt.callbacks.OnlineProbe(
input="embedding",
target="label",
probe=torch.nn.Linear(512, 10),
loss_fn=torch.nn.CrossEntropyLoss(),
loss=torch.nn.CrossEntropyLoss(),
metrics={
"top1": torchmetrics.classification.MulticlassAccuracy(10),
"top5": torchmetrics.classification.MulticlassAccuracy(10, top_k=5),
Expand Down Expand Up @@ -326,14 +327,14 @@ The `spt` command launches training from YAML configuration files using Hydra.

```bash
# Run with a config file
spt examples/simclr_cifar10_config.yaml
spt run examples/simclr_cifar10_config.yaml

# With parameter overrides
spt examples/simclr_cifar10_config.yaml trainer.max_epochs=50 module.optim.lr=0.01
spt run examples/simclr_cifar10_config.yaml trainer.max_epochs=50 module.optim.lr=0.01

# Run from any directory - supports absolute and relative paths
spt ../configs/my_config.yaml
spt /path/to/config.yaml
spt run ../configs/my_config.yaml
spt run /path/to/config.yaml
```

### SLURM Cluster Training
Expand All @@ -342,10 +343,10 @@ For training on SLURM clusters, use the `-m` flag to enable multirun mode:

```bash
# Use the provided SLURM template (customize partition/QOS in the file)
spt examples/simclr_cifar10_slurm.yaml -m
spt run examples/simclr_cifar10_slurm.yaml -m

# Override SLURM parameters via command line
spt examples/simclr_cifar10_slurm.yaml -m \
spt run examples/simclr_cifar10_slurm.yaml -m \
hydra.launcher.partition=gpu \
hydra.launcher.qos=normal \
hydra.launcher.timeout_min=720
Expand Down Expand Up @@ -378,12 +379,10 @@ The library is not yet available on PyPI. You can install it from the source cod
uv pip install -e . # Core dependencies only
```

For optional features (vision models, experiment tracking, cluster support, etc.):
For development (tests, linting, docs):
```bash
uv pip install -e ".[vision,tracking]" # Example: add vision models and wandb
uv pip install -e ".[all]" # Or install all optional dependencies
uv pip install -e ".[dev]"
```
See `pyproject.toml` for available dependency groups (`vision`, `tracking`, `cluster`, `visualization`, `datasets`, `extras`, `dev`, `doc`).

If you do not want to use uv, simply remove it from the above commands.

Expand Down
Loading