This is a template repository for training machine learning models on neural data with PyTorch Lightning. It implements a simple next-token prediction model on the Braintree Bank dataset and is designed for easy customization and extension, with a structured setup for data handling, model definition, and training. It uses neural-data libraries such as `crane` and `torch_brain` to facilitate working with neural data and features.
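As a reminder of what the next-token objective computes, here is a toy pure-Python version of the per-token negative log-likelihood. The template itself presumably uses `torch.nn.CrossEntropyLoss`; this sketch is only illustrative:

```python
import math

def next_token_nll(logits, targets):
    """Average negative log-likelihood of each target token given the
    model's logits at the preceding position.

    Toy pure-Python stand-in for the usual cross-entropy loss; `logits`
    is a list of per-step score lists, `targets` the true token indices.
    """
    total = 0.0
    for step_logits, target in zip(logits, targets):
        # log-sum-exp with max subtracted for numerical stability
        z = max(step_logits)
        log_norm = z + math.log(sum(math.exp(s - z) for s in step_logits))
        total += log_norm - step_logits[target]
    return total / len(targets)
```

Training then amounts to minimizing this quantity over sequences of (neural-feature-conditioned) tokens.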
1. Install the required dependencies (we recommend using `uv` for managing virtual environments):

   ```bash
   uv sync
   # if not using uv, use: pip install -e ".[dev]" inside a virtual environment
   ```

2. Download and preprocess the Braintree Bank dataset (~230 GB) using the provided SLURM script (`scripts/data.sh`) or manually via `brainsets`, then update your `.env` file with the path to the dataset:

   ```bash
   ROOT_DIR=/path/to/braintree_bank_dataset
   ```

   Note that you may need to adjust the script and `scripts/env.sh` to fit your cluster setup.

3. Modify the model, training logic, and configuration files as needed (see below). In particular, set `wandb.project` and `wandb.entity` in the configuration to log your training runs to Weights & Biases.

4. Run the training script:

   ```bash
   uv run -m train [CLI overrides]
   # if not using uv, use: python -m train
   ```

   or, if using SLURM:

   ```bash
   sbatch scripts/train.sh [CLI overrides]
   ```
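For reference, the `ROOT_DIR` value set in step 2 would typically be read from the environment or the `.env` file at startup. A minimal sketch, assuming a plain-text `.env` with `KEY=value` lines (the template may instead rely on `python-dotenv` or Hydra's own environment resolution):

```python
import os
from pathlib import Path

def get_root_dir() -> Path:
    """Resolve ROOT_DIR from the environment, falling back to a local .env file.

    Hypothetical helper for illustration; check how the template actually
    loads its configuration before relying on this.
    """
    if "ROOT_DIR" in os.environ:
        return Path(os.environ["ROOT_DIR"])
    env_file = Path(".env")
    if env_file.exists():
        for line in env_file.read_text().splitlines():
            line = line.strip()
            if line.startswith("ROOT_DIR="):
                return Path(line.split("=", 1)[1])
    raise RuntimeError("ROOT_DIR not set; add it to your environment or .env file")
```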
- `configs/`: Configuration files for different training setups. We use Hydra for configuration management, allowing easy CLI overrides and organization.
- `train/`: Source code for data modules, models, and training scripts.
  - `model.py`: Model architecture.
  - `data_module.py`: PyTorch Lightning `DataModule` for loading and preprocessing the dataset.
  - `featurizer.py`: Feature extraction from raw neural data, applied to each sample in the dataset to create the input features for the model.
  - `pl_module.py`: PyTorch Lightning module that wraps the model and defines the training/validation steps.
  - `dataset.py`: Data loading and preprocessing. You can probably ignore this for now.
  - `train.py`: The main training script that orchestrates the training process.
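To make the featurizer's role concrete, here is a minimal sketch of a per-sample transform that bins spike times into fixed-width count vectors. The class name, interface, and binning scheme are illustrative assumptions, not the actual `featurizer.py` API:

```python
class BinnedSpikeFeaturizer:
    """Hypothetical featurizer: bins each unit's spike times into count vectors.

    Illustrates the "applied to each sample" contract described above;
    the template's real featurizer may expose a different interface.
    """

    def __init__(self, bin_size_s: float = 0.02, window_s: float = 1.0):
        self.bin_size_s = bin_size_s
        self.n_bins = round(window_s / bin_size_s)

    def __call__(self, spike_times_per_unit):
        # spike_times_per_unit: one list of spike times (seconds) per unit,
        # relative to the start of the sample window.
        features = [[0.0] * self.n_bins for _ in spike_times_per_unit]
        for unit, times in enumerate(spike_times_per_unit):
            for t in times:
                b = int(t / self.bin_size_s)
                if 0 <= b < self.n_bins:
                    features[unit][b] += 1.0
        return features  # shape: (num_units, n_bins)
```

The resulting (units × time-bins) array is what the data module would hand to the model as input features.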
- `scripts/`: Scripts to run on a SLURM cluster; may need to be adjusted to your cluster configuration.
  - `data.sh`: SLURM script to download and prepare the Braintree Bank dataset using `brainsets`.
  - `env.sh`: Sets up the environment for training jobs, including loading modules and activating virtual environments. Update this to fit your cluster's environment management.
  - `train.sh`: SLURM script to run the training job. Adjust the resource requests and environment setup as needed for your cluster, and pass any necessary CLI overrides for the training script.