|
| 1 | +# Copyright 2023 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +# This sentinel is a reminder to choose a real run name. |
| 16 | +run_name: '' |
| 17 | + |
| 18 | +metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. |
| 19 | +# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ |
| 20 | +write_metrics: True |
| 21 | +gcs_metrics: False |
| 22 | +# If true save config to GCS in {base_output_directory}/{run_name}/ |
| 23 | +save_config_to_gcs: False |
| 24 | +log_period: 100 |
| 25 | + |
| 26 | +pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers' |
| 27 | + |
| 28 | +# Flux params |
| 29 | +flux_name: "flux-dev" |
| 30 | +max_sequence_length: 512 |
| 31 | +time_shift: True |
| 32 | +base_shift: 0.5 |
| 33 | +max_shift: 1.15 |
| 34 | +# offloads t5 encoder after text encoding to save memory. |
| 35 | +offload_encoders: True |
| 36 | + |
| 37 | + |
| 38 | +unet_checkpoint: '' |
| 39 | +revision: 'refs/pr/95' |
| 40 | +# This will convert the weights to this dtype. |
| 41 | +# When running inference on TPUv5e, use weights_dtype: 'bfloat16' |
| 42 | +weights_dtype: 'bfloat16' |
| 43 | +# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) |
| 44 | +activations_dtype: 'bfloat16' |
| 45 | + |
| 46 | +# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision |
| 47 | +# Options are "DEFAULT", "HIGH", "HIGHEST" |
| 48 | +# fp32 activations and fp32 weights with HIGHEST will provide the best precision |
| 49 | +# at the cost of time. |
| 50 | +precision: "DEFAULT" |
| 51 | + |
| 52 | +# if False state is not jitted and instead replicate is called. This is good for debugging on single host |
| 53 | +# It must be True for multi-host. |
| 54 | +jit_initializers: True |
| 55 | + |
| 56 | +# Set true to load weights from pytorch |
| 57 | +from_pt: True |
| 58 | +split_head_dim: True |
| 59 | +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te |
| 60 | + |
| 61 | +flash_block_sizes: {} |
| 62 | +# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem. |
| 63 | +# flash_block_sizes: { |
| 64 | +# "block_q" : 1536, |
| 65 | +# "block_kv_compute" : 1536, |
| 66 | +# "block_kv" : 1536, |
| 67 | +# "block_q_dkv" : 1536, |
| 68 | +# "block_kv_dkv" : 1536, |
| 69 | +# "block_kv_dkv_compute" : 1536, |
| 70 | +# "block_q_dq" : 1536, |
| 71 | +# "block_kv_dq" : 1536 |
| 72 | +# } |
| 73 | +# GroupNorm groups |
| 74 | +norm_num_groups: 32 |
| 75 | + |
| 76 | +# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch |
| 77 | +# else they will be loaded from pretrained_model_name_or_path |
| 78 | +train_new_unet: False |
| 79 | + |
| 80 | +# train text_encoder - Currently not supported for SDXL |
| 81 | +train_text_encoder: False |
| 82 | +text_encoder_learning_rate: 4.25e-6 |
| 83 | + |
| 84 | +# https://arxiv.org/pdf/2305.08891.pdf |
| 85 | +snr_gamma: -1.0 |
| 86 | + |
| 87 | +timestep_bias: { |
| 88 | + # a value of later will increase the frequence of the model's final training steps. |
| 89 | + # none, earlier, later, range |
| 90 | + strategy: "none", |
| 91 | + # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. |
| 92 | + multiplier: 1.0, |
| 93 | + # when using strategy=range, the beginning (inclusive) timestep to bias. |
| 94 | + begin: 0, |
| 95 | + # when using strategy=range, the final step (inclusive) to bias. |
| 96 | + end: 1000, |
| 97 | + # portion of timesteps to bias. |
| 98 | + # 0.5 will bias one half of the timesteps. Value of strategy determines |
| 99 | + # whether the biased portions are in the earlier or later timesteps. |
| 100 | + portion: 0.25 |
| 101 | +} |
| 102 | + |
| 103 | +# Override parameters from checkpoints's scheduler. |
| 104 | +diffusion_scheduler_config: { |
| 105 | + _class_name: 'FlaxEulerDiscreteScheduler', |
| 106 | + prediction_type: 'epsilon', |
| 107 | + rescale_zero_terminal_snr: False, |
| 108 | + timestep_spacing: 'trailing' |
| 109 | +} |
| 110 | + |
| 111 | +# Output directory |
| 112 | +# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" |
| 113 | +base_output_directory: "" |
| 114 | + |
| 115 | +# Hardware |
| 116 | +hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' |
| 117 | + |
| 118 | +# Parallelism |
| 119 | +mesh_axes: ['data', 'fsdp', 'tensor'] |
| 120 | + |
| 121 | +# batch : batch dimension of data and activations |
| 122 | +# hidden : |
| 123 | +# embed : attention qkv dense layer hidden dim named as embed |
| 124 | +# heads : attention head dim = num_heads * head_dim |
| 125 | +# length : attention sequence length |
| 126 | +# temb_in : dense.shape[0] of resnet dense before conv |
| 127 | +# out_c : dense.shape[1] of resnet dense before conv |
| 128 | +# out_channels : conv.shape[-1] activation |
| 129 | +# keep_1 : conv.shape[0] weight |
| 130 | +# keep_2 : conv.shape[1] weight |
| 131 | +# conv_in : conv.shape[2] weight |
| 132 | +# conv_out : conv.shape[-1] weight |
| 133 | +logical_axis_rules: [ |
| 134 | + ['batch', 'data'], |
| 135 | + ['activation_batch', ['data','fsdp']], |
| 136 | + ['activation_heads', 'tensor'], |
| 137 | + ['activation_kv', 'tensor'], |
| 138 | + ['mlp','tensor'], |
| 139 | + ['embed','fsdp'], |
| 140 | + ['heads', 'tensor'], |
| 141 | + ['conv_batch', ['data','fsdp']], |
| 142 | + ['out_channels', 'tensor'], |
| 143 | + ['conv_out', 'fsdp'], |
| 144 | + ] |
| 145 | +data_sharding: [['data', 'fsdp', 'tensor']] |
| 146 | + |
| 147 | +# One axis for each parallelism type may hold a placeholder (-1) |
| 148 | +# value to auto-shard based on available slices and devices. |
| 149 | +# By default, product of the DCN axes should equal number of slices |
| 150 | +# and product of the ICI axes should equal number of devices per slice. |
| 151 | +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded |
| 152 | +dcn_fsdp_parallelism: -1 |
| 153 | +dcn_tensor_parallelism: 1 |
| 154 | +ici_data_parallelism: -1 |
| 155 | +ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded |
| 156 | +ici_tensor_parallelism: 1 |
| 157 | + |
| 158 | +# Dataset |
| 159 | +# Replace with dataset path or train_data_dir. One has to be set. |
| 160 | +dataset_name: 'diffusers/pokemon-gpt4-captions' |
| 161 | +train_split: 'train' |
| 162 | +dataset_type: 'tf' |
| 163 | +cache_latents_text_encoder_outputs: True |
| 164 | +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", |
| 165 | +# only apply to small dataset that fits in memory |
| 166 | +# prepare image latents and text encoder outputs |
| 167 | +# Reduce memory consumption and reduce step time during training |
| 168 | +# transformed dataset is saved at dataset_save_location |
| 169 | +dataset_save_location: '/tmp/pokemon-gpt4-captions_xl' |
| 170 | +train_data_dir: '' |
| 171 | +dataset_config_name: '' |
| 172 | +jax_cache_dir: '' |
| 173 | +hf_data_dir: '' |
| 174 | +hf_train_files: '' |
| 175 | +hf_access_token: '' |
| 176 | +image_column: 'image' |
| 177 | +caption_column: 'text' |
| 178 | +resolution: 1024 |
| 179 | +center_crop: False |
| 180 | +random_flip: False |
| 181 | +# If cache_latents_text_encoder_outputs is True |
| 182 | +# the num_proc is set to 1 |
| 183 | +tokenize_captions_num_proc: 4 |
| 184 | +transform_images_num_proc: 4 |
| 185 | +reuse_example_batch: False |
| 186 | +enable_data_shuffling: True |
| 187 | + |
| 188 | +# checkpoint every number of samples, -1 means don't checkpoint. |
| 189 | +checkpoint_every: -1 |
| 190 | +# enables one replica to read the ckpt then broadcast to the rest |
| 191 | +enable_single_replica_ckpt_restoring: False |
| 192 | + |
| 193 | +# Training loop |
| 194 | +learning_rate: 4.e-7 |
| 195 | +scale_lr: False |
| 196 | +max_train_samples: -1 |
| 197 | +# max_train_steps takes priority over num_train_epochs. |
| 198 | +max_train_steps: 200 |
| 199 | +num_train_epochs: 1 |
| 200 | +seed: 0 |
| 201 | +output_dir: 'sdxl-model-finetuned' |
| 202 | +per_device_batch_size: 1 |
| 203 | + |
| 204 | +warmup_steps_fraction: 0.0 |
| 205 | +learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. |
| 206 | + |
| 207 | +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before |
| 208 | +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. |
| 209 | + |
| 210 | +# AdamW optimizer parameters |
| 211 | +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. |
| 212 | +adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. |
| 213 | +adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. |
| 214 | +adam_weight_decay: 1.e-2 # AdamW Weight decay |
| 215 | +max_grad_norm: 1.0 |
| 216 | + |
| 217 | +enable_profiler: False |
| 218 | +# Skip first n steps for profiling, to omit things like compilation and to give |
| 219 | +# the iteration time a chance to stabilize. |
| 220 | +skip_first_n_steps_for_profiler: 5 |
| 221 | +profiler_steps: 10 |
| 222 | + |
| 223 | +# Generation parameters |
| 224 | +prompt: "A magical castle in the middle of a forest, artistic drawing" |
| 225 | +prompt_2: "A magical castle in the middle of a forest, artistic drawing" |
| 226 | +negative_prompt: "purple, red" |
| 227 | +do_classifier_free_guidance: True |
| 228 | +guidance_scale: 3.5 |
| 229 | +# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf |
| 230 | +guidance_rescale: 0.0 |
| 231 | +num_inference_steps: 50 |
| 232 | + |
| 233 | +# SDXL Lightning parameters |
| 234 | +lightning_from_pt: True |
| 235 | +# Empty or "ByteDance/SDXL-Lightning" to enable lightning. |
| 236 | +lightning_repo: "" |
| 237 | +# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. |
| 238 | +lightning_ckpt: "" |
| 239 | + |
| 240 | +# LoRA parameters |
| 241 | +# Values are lists to support multiple LoRA loading during inference in the future. |
| 242 | +lora_config: { |
| 243 | + lora_model_name_or_path: [], |
| 244 | + weight_name: [], |
| 245 | + adapter_name: [], |
| 246 | + scale: [], |
| 247 | + from_pt: [] |
| 248 | +} |
| 249 | +# Ex with values: |
| 250 | +# lora_config : { |
| 251 | +# lora_model_name_or_path: ["ByteDance/Hyper-SD"], |
| 252 | +# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], |
| 253 | +# adapter_name: ["hyper-sdxl"], |
| 254 | +# scale: [0.7], |
| 255 | +# from_pt: [True] |
| 256 | +# } |
| 257 | + |
| 258 | +enable_mllog: False |
| 259 | + |
| 260 | +#controlnet |
| 261 | +controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' |
| 262 | +controlnet_from_pt: True |
| 263 | +controlnet_conditioning_scale: 0.5 |
| 264 | +controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' |
| 265 | +quantization: '' |
| 266 | +# Shard the range finding operation for quantization. By default this is set to number of slices. |
| 267 | +quantization_local_shard_count: -1 |
| 268 | +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. |
| 269 | + |
0 commit comments