A library for measuring and analyzing distributional symmetry breaking in machine learning models.
This project is licensed under the MIT License - see the LICENSE file for details.
```
dist-symm-breaking/
├── src/dsb/                      # Main Python package (pip installable)
│   ├── datasets.py               # Dataset classes and registry
│   ├── models.py                 # Model architectures
│   ├── train.py                  # Training and evaluation
│   ├── utils.py                  # Utility functions
│   ├── baselines.py              # Baseline methods
│   └── configs/                  # Hydra configuration files
├── scripts/                      # Entry points for running experiments
│   ├── main.py                   # Main CLI entry point
│   └── test_readme_commands.py   # Automated testing
├── analysis/                     # Analysis and visualization scripts
├── notebooks/                    # Jupyter notebooks for exploration
├── ridge_regression/             # Theoretical visualizations (see below)
└── crystal-text-llm/             # Modified LLM for crystal experiments (see below)
```
`ridge_regression/`: Contains scripts for visualizing the theoretical concepts from the paper. These are standalone visualizations of the ridge regression analysis and are not part of the core method.

`crystal-text-llm/`: A minimal modification of facebook/crystal-text-llm used for one specific experiment in the paper. This is not part of the main DSB method; it's included only to reproduce that particular experiment.
```bash
# From the repository root directory
pip install -e .

# Or with uv
uv pip install -e .
```

This installs the `dsb` package, allowing you to import it from anywhere:
```python
from dsb import DATASET_REGISTRY, train_model
from dsb.datasets import MetaAugmentedDataset
from dsb.models import MoleculeNet
```

Set the `DATA_DIR` environment variable to point to your data directory:

```bash
export DATA_DIR=/path/to/your/data
```

Datasets are expected in subdirectories named after the dataset (e.g., `$DATA_DIR/mnist/`). Wandb logging is enabled by default but can be disabled with `++training.use_wandb=False`.
After installation, run experiments from the scripts/ directory:
```bash
cd scripts

# Basic vanilla command - MNIST classification with default settings
python main.py --config-name=mnist_classification ++training.use_wandb=False ++save_dir=test_run

# Swiss Roll classification
python main.py --config-name=swiss_roll_classification ++save_dir=test_run

# QM9 regression with e3nn model
python main.py --config-name=qm9_regression model=e3convnet ++save_dir=test_run

# MNIST digit classification (simplest example)
python main.py --config-name=mnist_classification ++training.epochs=10 ++training.use_wandb=False ++save_dir=mnist_test

# Swiss Roll binary classification
python main.py --config-name=swiss_roll_classification ++training.epochs=50 ++training.use_wandb=False ++save_dir=swissroll_test

# Detection metric on MNIST
python main.py --config-name=mnist_detection ++training.epochs=10 ++training.use_wandb=False ++save_dir=mnist_detect_test
```

We use Hydra for configuration management. Configs are hierarchical and combined at runtime.
Key syntax:
- `++field=value` - Override an existing field, or create it if it's missing
- `+field=value` - Add a new field (errors if it already exists)
- `model=e3convnet` - Select a config file from `configs/model/`
Example:
```bash
python main.py +save_dir='run_name' model=molnet ++training.epochs=20
```

A test script automatically extracts and runs all `python main.py` commands from this README:
```bash
# Dry run - see what commands would be executed
python scripts/test_readme_commands.py --dry-run

# Run all tests (1 epoch each, wandb disabled)
python scripts/test_readme_commands.py

# Test specific datasets
python scripts/test_readme_commands.py --datasets mnist swiss_roll

# Stop on first failure (useful for debugging)
python scripts/test_readme_commands.py --stop-on-failure --keep-failed
```

Cleanup options (to avoid accumulating test outputs):
```bash
# Recommended: keep outputs only for failed tests (for debugging)
python scripts/test_readme_commands.py --keep-failed

# Keep only log files, remove model checkpoints
python scripts/test_readme_commands.py --keep-logs

# Remove all outputs after tests complete
python scripts/test_readme_commands.py --cleanup
```

This repo uses nbstripout as a git filter to strip notebook outputs at commit time. Your local notebook outputs are preserved; only the staged/committed version is stripped.
New contributors should set up the filter after cloning:
```bash
pip install nbstripout
nbstripout --install
```

Several configs and scripts reference `DATA_DIR` and `WANDB_USER` via environment variables. Make sure these are set before running experiments or notebooks:

```bash
export DATA_DIR=/path/to/your/data
export WANDB_USER=your_wandb_entity
```

Datasets are registered via the `DATASET_REGISTRY` pattern in `src/dsb/datasets.py`. Each dataset defines a `DatasetConfig` subclass that encapsulates all dataset-specific logic.
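Scripts that depend on these variables benefit from failing early with a clear message. A minimal sketch of such a check (the helper name `resolve_data_dir` is illustrative, not part of the package):

```python
import os
from pathlib import Path

def resolve_data_dir() -> Path:
    """Read DATA_DIR from the environment and fail early with a clear message."""
    data_dir = os.environ.get("DATA_DIR")
    if data_dir is None:
        raise RuntimeError(
            "DATA_DIR is not set. Export it first, e.g. "
            "`export DATA_DIR=/path/to/your/data`."
        )
    return Path(data_dir).expanduser()
```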
Add your dataset class to `src/dsb/datasets.py`. It should have basic functionality: `__getitem__(self, idx)` should return a pair `(datapoint, idx)`.

You also need to define the associated operators to canonicalize and label your data:

- `canonicalize_operator(data, idx)` - takes data and idx as input, and outputs the canonicalized data.
- `label_operator(data)` - takes just a datapoint as input, and outputs its label. (For standard torchvision datasets this operator is trivial, as the datapoint already includes the label as the second element of a tuple, but for molecular datasets the label may need to be extracted in a dataset-dependent way.)
- `transform_operator(data, idx)` - takes a datapoint as input (again, however it is returned by the base dataset), and transforms it. This allows for both invariance and equivariance, depending on how it is defined. You can look at `ToyCircleDataset` as an example.
For equivariant quantities (vectors, tensors), the transform operator must co-rotate all relevant fields alongside positions — see rotate_qm7_quantities in utils.py for an example.
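To make the operator contract concrete, here is a hedged sketch on a toy 2D point dataset. Everything except the operator signatures (the `ToyPoints` class, the choice of canonical frame and label) is illustrative; see `ToyCircleDataset` in `src/dsb/datasets.py` for the real example.

```python
import numpy as np

class ToyPoints:
    """Toy 2D dataset: each item is a pair (datapoint, idx), as the registry expects."""
    def __init__(self, n=100, seed=0):
        rng = np.random.default_rng(seed)
        self.points = rng.normal(size=(n, 2))

    def __len__(self):
        return len(self.points)

    def __getitem__(self, idx):
        return self.points[idx], idx

def canonicalize_operator(data, idx):
    """Rotate the point onto the positive x-axis (a trivial canonical frame)."""
    angle = np.arctan2(data[1], data[0])
    c, s = np.cos(-angle), np.sin(-angle)
    return np.array([c * data[0] - s * data[1], s * data[0] + c * data[1]])

def label_operator(data):
    """Label by distance from the origin (an illustrative rotation-invariant target)."""
    return np.linalg.norm(data)

def transform_operator(data, idx):
    """Apply a random rotation; the invariant label above is unchanged by this."""
    theta = np.random.default_rng(idx).uniform(0, 2 * np.pi)
    c, s = np.cos(theta), np.sin(theta)
    return np.array([c * data[0] - s * data[1], s * data[0] + c * data[1]])
```

With these definitions, canonicalization maps every rotated copy of a point to the same representative, and the label is invariant under the transform, which is exactly the structure the detection metrics probe.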
```python
import torch
import torch.nn as nn

from dsb.datasets import DATASET_REGISTRY, DatasetConfig

class MyDatasetConfig(DatasetConfig):
    def create_dataset(self, dataset_dir, cfg):
        # Instantiate dataset, compute means/stds if needed
        base_dataset = MyDataset(root=dataset_dir)
        means, stds = compute_means_stds(base_dataset)
        return base_dataset, means, stds

    def get_label_operator(self, base_dataset, task, transform_operator=None):
        # Return a callable: data -> (data, label)
        # Override for task-specific behavior (e.g. predict_g)
        return base_dataset.label_operator

    def get_criterion(self, task, dataset_cfg, means=None, stds=None):
        # Return the primary loss (e.g. nn.MSELoss(), nn.CrossEntropyLoss())
        return nn.MSELoss()

    def get_aux_criteria(self, task, dataset_cfg, means=None, stds=None):
        # Return a dict of auxiliary metrics: {name: callable(outputs, targets)}
        return {'MAE': lambda o, t: torch.mean(torch.abs(o - t)).item()}

    def get_split_function(self, dataset_name, dataset_cfg):
        # Return 'standard' or a custom split name
        return 'standard'

DATASET_REGISTRY['my_new_dataset'] = MyDatasetConfig()
```

Add a model class to `src/dsb/models.py`. The class name of this model will be used in the config.
Create config files in src/dsb/configs/:
- `configs/dataset/my_new_dataset.yaml` - dataset parameters (`name`, `task`, `target`, `num_classes`, etc.)
- `configs/model/my_model.yaml` - model parameters
- `configs/my_new_dataset_task.yaml` - top-level config combining dataset, model, and training
If you would like to use the task-dependent metric, you also need to add (1) a model that learns to output a canonicalization c(x), (2) a binary classification model that takes as input pairs (c(x), y), and/or (3) a prediction model that takes as input just c(x) and tries to predict y.
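The interfaces of these three models can be sketched as follows. This is an illustrative outline of the expected inputs and outputs, not the repository's architectures; the class names, hidden sizes, and layer choices are assumptions.

```python
import torch
import torch.nn as nn

class CanonModel(nn.Module):
    """(1) Learns a canonicalization c(x) for each input x."""
    def __init__(self, in_dim, c_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, 64), nn.ReLU(), nn.Linear(64, c_dim))
    def forward(self, x):
        return self.net(x)

class PairClassifier(nn.Module):
    """(2) Binary classifier on pairs (c(x), y)."""
    def __init__(self, c_dim, y_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(c_dim + y_dim, 64), nn.ReLU(), nn.Linear(64, 1))
    def forward(self, c, y):
        return self.net(torch.cat([c, y], dim=-1))

class DirectPredictor(nn.Module):
    """(3) Predicts y from c(x) alone."""
    def __init__(self, c_dim, y_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(c_dim, 64), nn.ReLU(), nn.Linear(64, y_dim))
    def forward(self, c):
        return self.net(c)
```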
Note that the train- and test-time augmentations can be toggled on or off, as shown below.
```bash
python main.py --config-name=swiss_roll_classification ++dataset.prob=1.0 ++dataset.augment_args.do_augment=True ++dataset.augment_args.transform=1.0 ++dataset.augment_args.train=True ++dataset.augment_args.val=False ++dataset.augment_args.test=False ++save_dir=ignore

python main.py --config-name=swiss_roll_detection ++dataset.prob=1.0 ++dataset.augment_args.do_augment=False ++save_dir=ignore

python main.py --config-name=swiss_roll_task_detection ++dataset.prob=1.0 ++dataset.augment_args.do_augment=False ++dataset.task_dependent_args.c_args.learned=False ++save_dir=ignore

python main.py --config-name=swiss_roll_task_direct ++dataset.prob=1.0 ++dataset.augment_args.do_augment=False ++dataset.task_dependent_args.c_args.learned=False ++save_dir=ignore
```
As before, one can turn augmentation on/off in the same way.
```bash
python main.py --config-name=mnist_classification ++save_dir=ignore
```
To run the group averaged model, use the config mnist_classification_c4_av.
```bash
python main.py --config-name=mnist_detection ++dataset.augment_args.do_augment=False ++save_dir=ignore

python main.py --config-name=mnist_task_detection ++dataset.augment_args.do_augment=False ++dataset.task_dependent_args.c_args.learned=False ++save_dir=ignore

python main.py --config-name=mnist_task_direct ++dataset.augment_args.do_augment=False ++dataset.task_dependent_args.c_args.learned=False ++save_dir=ignore
```
Note that the model can be changed, and the property to predict can be selected by setting `dataset.target`, as shown below. For the qm9_atomic dataset we use the standard Anderson splits (i.e. those from Cormorant, which were subsequently used by EDM and its follow-ups).
```bash
python main.py --config-name=qm9_regression model=e3convnet ++dataset.augment_args.do_augment=True ++dataset.augment_args.transform=1.0 ++dataset.augment_args.train=True ++dataset.augment_args.val=False ++dataset.augment_args.test=False ++dataset.name='qm9_atomic' ++dataset.split_args.split_type='anderson' ++dataset.target='U0' ++save_dir=ignore
```
To run the e3nn model, add model=e3convnet. To run the so3 group averaged model, change the config name to qm9_regression_so3_ave.
```bash
python main.py --config-name=qm9_detection model=transformer ++dataset.augment_args.do_augment=False ++save_dir=ignore
```
Note that we use the ordinary QM9 dataset from torch geometric for this experiment.
```bash
python main.py --config-name=qm9_predict_g model=transformer ++dataset.augment_args.do_augment=False ++save_dir=ignore ++dataset.split_args.split_type=ignore

python main.py --config-name=qm9_task_detection ++dataset.augment_args.do_augment=False ++dataset.task_dependent_args.c_args.learned=False ++save_dir=ignore

python main.py --config-name=qm9_task_direct ++dataset.augment_args.do_augment=False ++dataset.task_dependent_args.c_args.learned=False ++save_dir=ignore

python main.py --config-name=local_qm9_detection model=transformer ++dataset.augment_args.do_augment=False ++save_dir=ignore

python main.py --config-name=local_qm9_predict_g model=transformer ++dataset.augment_args.do_augment=False ++save_dir=ignore
```
The augment_args flags control data pre-processing applied before the detection metric's own transform:
- `do_augment=False` - Original: raw data, no pre-processing.
- `do_augment=True, transform=1.0` - Random rotations: every sample is randomly rotated before detection.
- `do_augment=True, canonicalize=1.0` - Canonicalized: every sample is PCA-canonicalized before detection.
Set train=True, val=True, test=True to apply the augmentation to all splits.
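The PCA canonicalization referenced by `canonicalize=1.0` can be sketched as follows. This is an illustrative implementation, not the repository's exact code:

```python
import numpy as np

def pca_canonicalize(points: np.ndarray) -> np.ndarray:
    """Rotate a point cloud of shape (N, 3) into its PCA frame: center it,
    then align the principal axes with the coordinate axes.
    Note: without a sign-fixing convention the PCA frame is only defined up
    to axis flips; real implementations typically add one."""
    centered = points - points.mean(axis=0)
    # Eigenvectors of the covariance matrix give the principal axes.
    _, eigvecs = np.linalg.eigh(np.cov(centered.T))
    # Project onto the principal axes, largest-variance axis first.
    return centered @ eigvecs[:, ::-1]
```

After this step, any rotated copy of the same point cloud lands in the same frame (up to sign flips), which is what makes canonicalized inputs a useful contrast to raw and randomly rotated ones.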
```bash
python main.py --config-name=local_qm9_detection model=transformer ++dataset.augment_args.do_augment=False ++save_dir=ignore

python main.py --config-name=local_qm9_detection model=transformer ++dataset.augment_args.do_augment=True ++dataset.augment_args.transform=1.0 ++dataset.augment_args.train=True ++dataset.augment_args.val=True ++dataset.augment_args.test=True ++save_dir=ignore

python main.py --config-name=local_qm9_detection model=transformer ++dataset.augment_args.do_augment=True ++dataset.augment_args.canonicalize=1.0 ++dataset.augment_args.train=True ++dataset.augment_args.val=True ++dataset.augment_args.test=True ++save_dir=ignore

python main.py --config-name=qm9_detection model=transformer ++dataset.augment_args.do_augment=True ++dataset.augment_args.transform=1.0 ++dataset.augment_args.train=True ++dataset.augment_args.val=True ++dataset.augment_args.test=True ++save_dir=ignore

python main.py --config-name=qm9_detection model=transformer ++dataset.augment_args.do_augment=True ++dataset.augment_args.canonicalize=1.0 ++dataset.augment_args.train=True ++dataset.augment_args.val=True ++dataset.augment_args.test=True ++save_dir=ignore
```
To run all 25 trials with early stopping and wandb logging:
```bash
bash scripts/run_qm9_detection_settings.sh
```

Results are saved to `$DATA_DIR/checkpoints/results/qm9_detection_settings/`.
Data is based on the QM7b dataset, taken from this paper that calculated response properties at varying levels of DFT accuracy. Data can be downloaded from https://archive.materialscloud.org/record/2019.0002/v3 (CCSD_daDZ.tar.gz). Untar the resulting .xyz files into $DATA_DIR/qm7/.
```bash
python main.py --config-name=qm7_regression ++training.use_wandb=True model=graphormer ++training.epochs=500 ++save_dir=ignore
```
Uses 1x0e+1x2e irreps output to predict the polarizability tensor in equivariant form. Override dataset.target=quadrupole to predict the quadrupole tensor instead.
```bash
python main.py --config-name=qm7_regression_tensor ++training.use_wandb=False ++save_dir=ignore
```
Note: For e3convnet, use the config `model/e3convnet_qm7`. Change `irreps_out` depending on whether you're predicting a scalar or a higher-order property. For the group averaged model, use `--config-name=qm7_regression_so3_ave` and set `++model.model_args.output_is_vector=False` for scalar outputs or `++model.model_args.output_is_vector=True` for vector outputs (e.g. dipole). Without `output_is_vector=True`, the SO3 averaging will rotate predictions to different frames before averaging, causing the vector predictions to cancel out.
For non-equivariant models (e.g. Graphormer), set ++dataset.to_irreps=false so labels stay in flat [xx, yy, zz, xy, xz, yz] format instead of irreps format. The output size remains 6 either way.
```bash
python main.py --config-name=qm7_regression_tensor model=graphormer ++dataset.to_irreps=false ++training.use_wandb=False ++save_dir=ignore
```
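The flat layout is just the six independent components of a symmetric 3x3 tensor. A hedged sketch of the conversion (illustrative helpers, not the package's converter):

```python
import numpy as np

def symmetric_to_flat(t: np.ndarray) -> np.ndarray:
    """Flatten a symmetric 3x3 tensor into [xx, yy, zz, xy, xz, yz]."""
    return np.array([t[0, 0], t[1, 1], t[2, 2], t[0, 1], t[0, 2], t[1, 2]])

def flat_to_symmetric(v: np.ndarray) -> np.ndarray:
    """Inverse: rebuild the symmetric 3x3 tensor from its 6 components."""
    xx, yy, zz, xy, xz, yz = v
    return np.array([[xx, xy, xz],
                     [xy, yy, yz],
                     [xz, yz, zz]])
```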
For dipole prediction with e3convnet, use irreps_out=1x1o (odd-parity vector, 3 components) and set num_classes=3. Note that to_irreps has no effect on dipole — the label is always a flat 3-vector.
```bash
python main.py --config-name=qm7_regression_tensor ++dataset.target=dipole ++dataset.num_classes=3 ++dataset.filter_polar=true ++model.model_args.irreps_out=1x1o ++training.use_wandb=False ++save_dir=ignore
```
```bash
python main.py --config-name=qm7_detection ++save_dir=ignore

python main.py --config-name=qm7_task_detection ++dataset.task_dependent_args.c_args.learned=False ++save_dir=ignore

python main.py --config-name=qm7_task_direct ++dataset.task_dependent_args.c_args.learned=False ++save_dir=ignore
```
Data setup: PyTorch Geometric downloads and processes ModelNet40 automatically on first run. Just set DATA_DIR to a writable directory — no manual download needed.
On first run, PyG will:
- Download `ModelNet40.zip` (~400 MB) from modelnet.cs.princeton.edu
- Process all 12,311 meshes into `$DATA_DIR/modelnet/processed/` (~7.4 GB total)
This one-time processing step takes 15–20 minutes (longer over NFS). Subsequent runs load the cached .pt files and start quickly.
```bash
python main.py dataset=modelnet_detection model=transformer_for_modelnet training=modelnet_training ++training.epochs=30

DATA_DIR=... python main.py dataset=modelnet_task_direct model=mlp_for_rot_c training=modelnet_training ++dataset.task_dependent_args.c_args.learned=False ++training.epochs=300

DATA_DIR=... python main.py dataset=modelnet_task_detection model=mlp_for_rot_c_y training=modelnet_training ++dataset.task_dependent_args.c_args.learned=False ++training.epochs=1200

# Augmentation setting can be changed via augment_args in the dataset config.
DATA_DIR=... python main.py dataset=modelnet_classification model=transformer_for_modelnet training=modelnet_training ++model.hidden_dim=256 ++model.num_heads=8 ++model.num_layers=6 ++training.epochs=300 ++training.batch_size=128
```
At the moment, we have only implemented the task-independent metric for this dataset. The molecule is specified with dataset.mol_name.
```bash
python main.py --config-name=md17_detection ++save_dir=ignore
```
At the moment, we have only implemented the task-independent metric for this dataset. The catalyst or adsorbate can be filtered with dataset.filter_mol.
```bash
python main.py --config-name=oc20_detection ++save_dir=ignore
```