MMD-DUAL

This package implements the MMD-DUAL test proposed in DUAL: Learning Diverse Kernels for Aggregated Two-sample and Independence Testing (NeurIPS 2025) by Zhou, Tian, Peng, Lei, Schrab, Sutherland, and Liu.

For the full reproducibility experiments, please see the complete repository.

Requirements

  • Python 3.9+
  • torch
  • numpy
  • scipy

The code automatically detects and uses available hardware accelerators:

  • CUDA (NVIDIA GPUs)
  • MPS (Apple Silicon GPUs)
  • CPU
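
The selection priority above can be sketched in a few lines of PyTorch. This is an illustrative sketch only; the function name select_device is hypothetical and not part of this package's API:

```python
import torch

def select_device() -> str:
    """Pick the best available accelerator: CUDA, then MPS, then CPU.

    Illustrative sketch of the priority described above, not the
    package's internal logic.
    """
    if torch.cuda.is_available():
        return "cuda"
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"

device = torch.device(select_device())
x = torch.randn(3, device=device)  # tensor created directly on the chosen device
```

On a machine without a GPU this simply falls back to `"cpu"`, so the same script runs unchanged across hardware.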

Installation

First, create a conda environment:

conda create -n mmd-dual python=3.9
conda activate mmd-dual

Install PyTorch with CUDA support (for GPU acceleration):

# For CUDA 11.8
pip install torch --index-url https://download.pytorch.org/whl/cu118

# For CUDA 12.1
pip install torch --index-url https://download.pytorch.org/whl/cu121

Or install PyTorch for CPU/MPS:

# For CPU only
pip install torch --index-url https://download.pytorch.org/whl/cpu

# For macOS (MPS support included by default)
pip install torch

Install the remaining dependencies:

pip install numpy scipy

Clone the repository:

git clone git@github.com:yeager20001118/MMD-HSIC-DUAL.git
cd MMD-HSIC-DUAL

MMD-DUAL

The MMD-DUAL test is a two-sample test that learns diverse kernels via aggregation. It consists of two phases:

  1. Training phase: Learn optimal kernel bandwidths from training data
  2. Testing phase: Perform the two-sample test on test data
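
For intuition about the quantity being tested, here is a minimal, self-contained sketch of a (biased) squared-MMD estimate with a single Gaussian kernel. The package's actual statistic aggregates many learned Gaussian and Laplacian bandwidths, so treat this purely as a conceptual illustration:

```python
import torch

def gaussian_kernel(A, B, bandwidth=1.0):
    # Pairwise squared Euclidean distances, then the Gaussian kernel.
    d2 = torch.cdist(A, B) ** 2
    return torch.exp(-d2 / (2 * bandwidth ** 2))

def mmd2_biased(X, Y, bandwidth=1.0):
    # Biased estimate of MMD^2 = E[k(x,x')] - 2 E[k(x,y)] + E[k(y,y')].
    Kxx = gaussian_kernel(X, X, bandwidth).mean()
    Kyy = gaussian_kernel(Y, Y, bandwidth).mean()
    Kxy = gaussian_kernel(X, Y, bandwidth).mean()
    return Kxx + Kyy - 2 * Kxy

torch.manual_seed(0)
X = torch.randn(200, 2)
Y = torch.randn(200, 2) + 0.5           # shifted distribution
same = mmd2_biased(X, torch.randn(200, 2)).item()  # P vs P: near zero
diff = mmd2_biased(X, Y).item()                    # P vs Q: clearly larger
```

A larger MMD value indicates a larger discrepancy between the two sample distributions; the permutation/bootstrap step in the testing phase calibrates how large is large enough to reject.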

Example

>>> import torch
>>> from MMD_DUAL.MMD_DUAL import train_MMD_DUAL, TST_MMD_DUAL
>>> # Generate sample data
>>> X_train = torch.randn(250, 2)
>>> Y_train = torch.randn(250, 2) + 0.5  # Shifted distribution
>>> # Train the DUAL model
>>> model = train_MMD_DUAL(X_train, Y_train, rs=0)
>>> # Generate test data
>>> X_test = torch.randn(250, 2)
>>> Y_test = torch.randn(250, 2) + 0.5
>>> # Run the two-sample test
>>> reject = TST_MMD_DUAL(X_test, Y_test, rs=0, n_per=1000, model_DUAL=model)
>>> print(f"Reject null hypothesis: {reject}")

Function Parameters

train_MMD_DUAL

| Parameter | Type | Default | Description |
|---|---|---|---|
| X_train | Tensor | - | Training samples from distribution P |
| Y_train | Tensor | - | Training samples from distribution Q |
| rs | int | - | Random seed |
| N_epoch | int | 1000 | Number of training epochs |
| batch_size | int | 100 | Batch size for training |
| learning_rate | float | 0.0005 | Learning rate |
| n_bandwidth | list | [10, 10] | Number of bandwidths for the [Gaussian, Laplacian] kernels |
| reg | float | 1e-8 | Regularization parameter |
| way | list | ['Agg', 'Fuse'] | Bandwidth selection method for each kernel |
| is_cov | bool | True | Whether to train with covariance |

TST_MMD_DUAL

| Parameter | Type | Default | Description |
|---|---|---|---|
| X_test | Tensor | - | Test samples from distribution P |
| Y_test | Tensor | - | Test samples from distribution Q |
| rs | int | - | Random seed |
| n_per | int | 100 | Number of permutations/bootstrap samples |
| alpha | float | 0.05 | Significance level |
| is_wild | bool | True | Use wild bootstrap (True) or permutation (False) |
| model_DUAL | DUAL | None | Trained DUAL model |

Returns True if the null hypothesis is rejected (distributions are different), False otherwise.
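
To illustrate what the permutation option (is_wild=False) does in general: pool the two samples, repeatedly reshuffle the labels, and compare the observed statistic against the resulting null distribution. The sketch below uses a simple mean-difference statistic in place of MMD; permutation_pvalue and mean_diff are illustrative names, not part of this package:

```python
import torch

def permutation_pvalue(X, Y, stat, n_per=1000, rs=0):
    """Generic permutation test: p-value of stat(X, Y) under label shuffling."""
    g = torch.Generator().manual_seed(rs)
    Z = torch.cat([X, Y])
    n = X.shape[0]
    observed = stat(X, Y)
    count = 0
    for _ in range(n_per):
        idx = torch.randperm(Z.shape[0], generator=g)
        if stat(Z[idx[:n]], Z[idx[n:]]) >= observed:
            count += 1
    # Add-one correction keeps the test valid at level alpha.
    return (count + 1) / (n_per + 1)

# Toy statistic: Euclidean norm of the difference of sample means.
mean_diff = lambda A, B: (A.mean(0) - B.mean(0)).norm().item()

torch.manual_seed(0)
X = torch.randn(100, 2)
Y = torch.randn(100, 2) + 0.5  # shifted distribution
p = permutation_pvalue(X, Y, mean_diff, n_per=200)
reject = p <= 0.05
```

The wild bootstrap (is_wild=True) calibrates the same decision differently, by randomly sign-flipping contributions to the statistic rather than permuting the pooled sample.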

Contact

If you have any issues running the MMD-DUAL test, please feel free to contact Xunye Tian. A polished HSIC-DUAL implementation will be added in a future update; in the meantime, the complete version is available in the reproducibility repository.

Affiliations

Faculty of Engineering and IT, University of Melbourne

School of Mathematics and Statistics, University of Melbourne

Department of Computer Science and Technology, University of Cambridge

Department of Computer Science, Faculty of Science, University of British Columbia

Alberta Machine Intelligence Institute (Amii), Canada

Bibtex

@inproceedings{zhou2025dual,
  title={DUAL: Learning Diverse Kernels for Aggregated Two-sample and Independence Testing},
  author={Zhou, Zhijian and Tian, Xunye and Peng, Liuhua and Lei, Chao and Schrab, Antonin and Sutherland, Danica J and Liu, Feng},
  booktitle={NeurIPS},
  year={2025}
}

License

MIT License (see LICENSE)
