HDmoonSir/StableDiffusionWithCustomNetwork

Stable Diffusion with Custom Network for Tabular Data

This project fine-tunes a Stable Diffusion model to generate images based on tabular data inputs, instead of traditional text prompts. It utilizes a custom neural network to process tabular data and condition the diffusion model's UNet.

Features

  • Tabular Data Input: Generates images from numerical or categorical data.
  • Custom Conditioning: A custom network transforms tabular input into embeddings for the diffusion model.
  • Stable Diffusion v1.5: Built upon the robust and popular Stable Diffusion v1.5 model.
  • Mixed Precision Training: Supports float16 and bfloat16 mixed precision for lower memory usage and faster training.
  • Gradient Clipping: Implements gradient clipping to prevent gradient explosion during training.
  • Fine-Tuning: Modifies the UNet part of the Stable Diffusion pipeline to understand the new conditioning.
  • Easy Configuration: All experiment, model, and training parameters are managed through a single default.yaml file.
  • Ready-to-use Scripts: Includes train.py for training and inference.py for generating images with a trained model.
  • Reproducibility: Includes a Dockerfile for building a containerized environment.
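The custom conditioning described above can be pictured as a small network that maps a tabular vector to the embedding shape the UNet's cross-attention expects. A minimal sketch, assuming a simple MLP (the class name and layer sizes here are illustrative, not the repository's actual architecture; 77 and 768 are the sequence length and dimension of Stable Diffusion v1.5's text-encoder output):

```python
import torch
import torch.nn as nn

class TabularConditioner(nn.Module):
    """Illustrative: maps a tabular vector to a (seq_len, dim) embedding
    that can stand in for the text-encoder output the UNet cross-attends to."""
    def __init__(self, num_input_layer: int, seq_len: int = 77, dim: int = 768):
        super().__init__()
        self.seq_len, self.dim = seq_len, dim
        self.net = nn.Sequential(
            nn.Linear(num_input_layer, 512),
            nn.SiLU(),
            nn.Linear(512, seq_len * dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_input_layer) -> (batch, seq_len, dim)
        return self.net(x).view(-1, self.seq_len, self.dim)

cond = TabularConditioner(num_input_layer=5)
emb = cond(torch.tensor([[1.0, 0.0, 0.0, 0.0, 0.0]]))
```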

Prerequisites

  • Python 3.8+
  • PyTorch
  • Git and Git LFS

Setup

  1. Clone the repository:

    git clone https://github.com/HDmoonSir/StableDiffusionWithCustomNetwork.git
    cd StableDiffusionWithCustomNetwork
  2. Install dependencies:

    pip install -r requirements.txt
  3. Download Pre-trained Stable Diffusion v1.5: This project requires the original weights for Stable Diffusion v1.5. You need to download them and place them in a directory. The default configuration in default.yaml expects them at ../../../weight_original/stable-diffusion-v1-5.

    You can download the weights from the Hugging Face Hub: stable-diffusion-v1-5

    Make sure you have git-lfs installed to download the model weights correctly.

Dataset Preparation

The dataloader expects a specific structure for the dataset, as shown in the data_example/sushi directory:

  • images/: A folder containing all the image files.
  • metadata.jsonl: A JSON Lines file where each line is a dictionary containing file_name and custom_input.

Example metadata.jsonl line:

{"file_name": "sushi_0.jpg", "custom_input": [1, 0, 0, 0, 0]}
  • file_name: The name of the image file in the images folder.
  • custom_input: A list of numbers representing your tabular data. The length of this list must match model_custom.num_input_layer in your configuration file.

Update the data.path_dataset in default.yaml to point to your dataset directory.
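A metadata.jsonl matching this layout can be written with a few lines of Python (the file names and tabular vectors below are made up for illustration):

```python
import json
from pathlib import Path

# One record per image: file_name plus the tabular conditioning vector.
records = [
    {"file_name": "sushi_0.jpg", "custom_input": [1, 0, 0, 0, 0]},
    {"file_name": "sushi_1.jpg", "custom_input": [0, 1, 0, 0, 1]},
]

path = Path("metadata.jsonl")
with path.open("w") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")  # one JSON object per line
```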

Configuration (default.yaml)

All settings are controlled via default.yaml. Key parameters to configure:

  • experiment: Settings for experiment name, save directory, and seed.
  • model_custom:
    • num_input_layer: The number of features in your tabular data input.
    • path_sd: Path to the pre-trained Stable Diffusion v1.5 weights.
  • train: Training parameters like epochs, batch size, and learning rates.
    • dtype: Data type for training. Supports "float32", "float16", and "bfloat16". Setting to "float16" or "bfloat16" enables mixed precision training automatically if CUDA is available.
  • data:
    • path_dataset: Path to your dataset directory.
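Putting these together, a default.yaml might look roughly like the following. The exact key nesting and default values here are a guess based on the parameter names above, so check the file shipped with the repository:

```yaml
experiment:
  name: exp
  path_save: ../../../exp_results
  seed: 42
model_custom:
  num_input_layer: 5
  path_sd: ../../../weight_original/stable-diffusion-v1-5
train:
  epochs: 100
  batch_size: 8
  lr: 1.0e-4
  dtype: float16   # float32 | float16 | bfloat16
data:
  path_dataset: ./data_example/sushi
```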

Training

To start training the model, run the train.py script. The script will:

  1. Load the configuration from default.yaml.
  2. Set up an experiment directory under exp_results with a timestamp.
  3. Save the configuration and code snapshots for reproducibility.
  4. Start the training process.
    python train.py

Checkpoints and the best model will be saved in the experiment directory (e.g., ../../../exp_results/YYYYMMDD_HHMMSS_exp).
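The mixed precision and gradient clipping features mentioned above follow a standard PyTorch pattern. A minimal sketch of one training step on a toy model (not the repository's actual loop; float16 autocast generally requires CUDA, so this falls back to bfloat16 on CPU):

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

model = nn.Linear(5, 4).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
# GradScaler is only needed for float16 on CUDA; it is a no-op when disabled.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

x = torch.randn(8, 5, device=device)
target = torch.randn(8, 4, device=device)

with torch.autocast(device_type=device, dtype=amp_dtype):
    loss = nn.functional.mse_loss(model(x), target)

scaler.scale(loss).backward()
scaler.unscale_(opt)  # unscale gradients before clipping their norm
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(opt)
scaler.update()
```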

Inference

After training, you can generate images using inference.py.

  1. Set the model path: Open inference.py and modify the path_infer variable to point to your experiment directory.

    # in inference.py
    path_infer = "../../../exp_results/your_experiment_timestamp_exp"
  2. Provide input data: Modify the list_values variable with the tabular data you want to generate images from. The input should be a string of space-separated numbers.

    # in inference.py
    values_input_0 = "0 1 0 0 1"
    values_input_1 = "1 0 0 0 0"
    list_values = [values_input_0, values_input_1]
  3. Run the script:

    python inference.py

Generated images will be saved in the img_gen directory.
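The space-separated input strings above are presumably parsed into numeric tensors before being fed to the conditioning network. A sketch of that conversion (the helper name here is ours, not from inference.py):

```python
import torch

def parse_values(values: str) -> torch.Tensor:
    """Turn a space-separated string like '0 1 0 0 1' into a float tensor."""
    return torch.tensor([float(v) for v in values.split()])

list_values = ["0 1 0 0 1", "1 0 0 0 0"]
batch = torch.stack([parse_values(v) for v in list_values])  # shape (2, 5)
```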
