
ViTAX: A Vision Transformer in JAX

This repository contains the source code for ViTAX: Building a Vision Transformer from Scratch, a project implementing a Vision Transformer (ViT) using JAX and the NNX library. For a comprehensive explanation of the Vision Transformer architecture, the JAX/NNX implementation details, and a step-by-step walkthrough of the code, please refer to the blog post.

The project focuses on understanding each component of the ViT architecture and demonstrates how to build and train it for an image classification task using the "Matthijs/snacks" dataset from Hugging Face.

Video Preview
Click preview to watch the full video (2.4MB MP4)

Features

Setup and Installation

  1. Clone the repository:

    git clone https://github.com/maurock/vitax.git
    cd vitax
  2. Create and activate a conda environment:

    conda env create -f environment.yml
    conda activate vitax

Data Preparation

The model is trained on the "Matthijs/snacks" dataset from Hugging Face. The src/dataset.py script handles downloading, preprocessing (resizing), and saving the dataset to disk.

  1. Run the dataset preparation script:
    python src/dataset.py
    This script will:
    • Download the "Matthijs/snacks" dataset.
    • Resize images to the configured dimensions (default 160x160).
    • Save the processed dataset splits (train, validation, test) to data/snack_dataset.hf/.
    • It also saves a dataset_config.yaml inside data/snack_dataset.hf/ that records the configuration used to create the dataset. This matters because the training step reads this information later (see the snippet below).
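  Once the script finishes, you can inspect the saved splits directly. The snippet below is a minimal sketch (not part of the repository) that assumes the splits were written with the Hugging Face datasets library's save_to_disk, so they can be read back with load_from_disk; the image and label field names are the usual columns of the snacks dataset.

    # Minimal sketch: inspect the preprocessed dataset written by src/dataset.py.
    # Assumes the splits were saved with datasets' save_to_disk.
    from datasets import load_from_disk

    dataset = load_from_disk("data/snack_dataset.hf")
    print(dataset)                      # DatasetDict with train/validation/test splits
    print(dataset["train"][0].keys())   # e.g. dict_keys(['image', 'label'])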

🚂 Training

Training is handled by the src/train.py script. It loads the preprocessed dataset, initializes the Vision Transformer model, and trains it with the Adam optimizer and cross-entropy loss.

  1. Configure training parameters (optional): Training hyperparameters (learning rate, batch size, number of epochs, model dimensions, etc.) are defined in configs/config.yml. You can modify this file to experiment with different settings.

    Here's a working example:

    # config.yml
    # Dataset. Set this according to your dataset configuration,
    # e.g. data/snack_dataset.hf/config.yaml
    width: 160
    height: 160
    path: "snack_dataset.hf" # path containing the split under data/
    
    # Model
    patch_size: 16
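    # With the 160x160 images above, a patch_size of 16 yields
    # (160 / 16)^2 = 100 non-overlapping patches per image, each
    # projected to a dim_emb-dimensional embedding.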
    num_classes: 20
    num_heads: 4
    num_encoder_blocks: 8
    dim_emb: 768
    
    # Training
    batch_size: 32
    learning_rate: 0.0001
    splits:
    - 'train'
    - 'validation'
    - 'test'
    num_epochs: 50
    output_name: my_weights  # stored in checkpoints/
    
  2. Run the training script:

    python src/train.py

    This script will:

    • Load the configuration from configs/config.yml.
    • Initialize the src.transformer.VisionTransformer model.
    • Load the preprocessed dataset from data/snack_dataset.hf/ (as prepared in the "Data Preparation" step).
    • Train the model using the src.train.run_training function, printing loss information for each epoch.
    • Save the trained model state under checkpoints/; the folder name is the value of output_name in configs/config.yml.
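
  For reference, the core update performed during training looks roughly like the sketch below. This is a minimal sketch under assumptions, not the repository's actual code: it assumes the model is a flax.nnx Module whose __call__ maps a batch of images to logits, and it uses optax's Adam and integer-label cross-entropy, matching the optimizer and loss named above.

    # Minimal sketch of an NNX training step with Adam and cross-entropy loss.
    # The VisionTransformer constructor arguments below are hypothetical.
    import optax
    from flax import nnx

    # model = VisionTransformer(...)                      # hypothetical args
    # optimizer = nnx.Optimizer(model, optax.adam(1e-4))  # learning_rate from config.yml

    def loss_fn(model, images, labels):
        logits = model(images)  # assumed: (batch, H, W, C) -> (batch, num_classes)
        return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

    @nnx.jit
    def train_step(model, optimizer, images, labels):
        loss, grads = nnx.value_and_grad(loss_fn)(model, images, labels)
        optimizer.update(grads)  # apply the Adam update to the model parameters
        return loss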

License

This project is licensed under the MIT License. See the LICENSE text in pyproject.toml for more details.
