Skip to content

Commit 6d54a84

Browse files
Wan vae implementation (#171)
Wan 2.1 VAE implementation + tests --------- Co-authored-by: Juan Acevedo <[email protected]>
1 parent 7284ca0 commit 6d54a84

21 files changed

+2235
-41
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ To generate images, run the following command:
171171
```bash
172172
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run"
173173
```
174+
174175
## Flux
175176

176177
First make sure you have permissions to access the Flux repos in Huggingface.

end_to_end/tpu/eval_assert.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
"""
2323

2424

25-
2625
# pylint: skip-file
2726
"""Reads and asserts over target values"""
2827
from absl import app
@@ -47,7 +46,7 @@ def test_final_loss(metrics_file, target_loss, num_samples_str="10"):
4746
target_loss = float(target_loss)
4847
num_samples = int(num_samples_str)
4948
with open(metrics_file, "r", encoding="utf8") as _:
50-
last_n_data = get_last_n_data(metrics_file, "learning/loss",num_samples)
49+
last_n_data = get_last_n_data(metrics_file, "learning/loss", num_samples)
5150
avg_last_n_data = sum(last_n_data) / len(last_n_data)
5251
print(f"Mean of last {len(last_n_data)} losses is {avg_last_n_data}")
5352
print(f"Target loss is {target_loss}")

requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,6 @@ huggingface_hub==0.30.2
3030
transformers==4.48.1
3131
einops==0.8.0
3232
sentencepiece
33-
aqtp
33+
aqtp
34+
imageio==2.37.0
35+
imageio-ffmpeg==0.6.0

requirements_with_jax_stable_stack.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,6 @@ tensorflow-datasets>=4.9.6
3131
tokenizers==0.21.0
3232
torch==2.5.1
3333
torchvision==0.20.1
34-
transformers==4.48.1
34+
transformers==4.48.1
35+
imageio==2.37.0
36+
imageio-ffmpeg==0.6.0
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# This sentinel is a reminder to choose a real run name.
16+
run_name: ''
17+
18+
metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
19+
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
20+
write_metrics: True
21+
gcs_metrics: False
22+
# If true save config to GCS in {base_output_directory}/{run_name}/
23+
save_config_to_gcs: False
24+
log_period: 100
25+
26+
pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
27+
28+
# Flux params
29+
flux_name: "flux-dev"
30+
max_sequence_length: 512
31+
time_shift: True
32+
base_shift: 0.5
33+
max_shift: 1.15
34+
# offloads t5 encoder after text encoding to save memory.
35+
offload_encoders: True
36+
37+
38+
unet_checkpoint: ''
39+
revision: 'refs/pr/95'
40+
# This will convert the weights to this dtype.
41+
# When running inference on TPUv5e, use weights_dtype: 'bfloat16'
42+
weights_dtype: 'bfloat16'
43+
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
44+
activations_dtype: 'bfloat16'
45+
46+
# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
47+
# Options are "DEFAULT", "HIGH", "HIGHEST"
48+
# fp32 activations and fp32 weights with HIGHEST will provide the best precision
49+
# at the cost of time.
50+
precision: "DEFAULT"
51+
52+
# if False state is not jitted and instead replicate is called. This is good for debugging on single host
53+
# It must be True for multi-host.
54+
jit_initializers: True
55+
56+
# Set true to load weights from pytorch
57+
from_pt: True
58+
split_head_dim: True
59+
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
60+
61+
flash_block_sizes: {}
62+
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
63+
# flash_block_sizes: {
64+
# "block_q" : 1536,
65+
# "block_kv_compute" : 1536,
66+
# "block_kv" : 1536,
67+
# "block_q_dkv" : 1536,
68+
# "block_kv_dkv" : 1536,
69+
# "block_kv_dkv_compute" : 1536,
70+
# "block_q_dq" : 1536,
71+
# "block_kv_dq" : 1536
72+
# }
73+
# GroupNorm groups
74+
norm_num_groups: 32
75+
76+
# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch
77+
# else they will be loaded from pretrained_model_name_or_path
78+
train_new_unet: False
79+
80+
# train text_encoder - Currently not supported for SDXL
81+
train_text_encoder: False
82+
text_encoder_learning_rate: 4.25e-6
83+
84+
# https://arxiv.org/pdf/2305.08891.pdf
85+
snr_gamma: -1.0
86+
87+
timestep_bias: {
88+
# a value of later will increase the frequence of the model's final training steps.
89+
# none, earlier, later, range
90+
strategy: "none",
91+
# multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it.
92+
multiplier: 1.0,
93+
# when using strategy=range, the beginning (inclusive) timestep to bias.
94+
begin: 0,
95+
# when using strategy=range, the final step (inclusive) to bias.
96+
end: 1000,
97+
# portion of timesteps to bias.
98+
# 0.5 will bias one half of the timesteps. Value of strategy determines
99+
# whether the biased portions are in the earlier or later timesteps.
100+
portion: 0.25
101+
}
102+
103+
# Override parameters from checkpoints's scheduler.
104+
diffusion_scheduler_config: {
105+
_class_name: 'FlaxEulerDiscreteScheduler',
106+
prediction_type: 'epsilon',
107+
rescale_zero_terminal_snr: False,
108+
timestep_spacing: 'trailing'
109+
}
110+
111+
# Output directory
112+
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
113+
base_output_directory: ""
114+
115+
# Hardware
116+
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
117+
118+
# Parallelism
119+
mesh_axes: ['data', 'fsdp', 'tensor']
120+
121+
# batch : batch dimension of data and activations
122+
# hidden :
123+
# embed : attention qkv dense layer hidden dim named as embed
124+
# heads : attention head dim = num_heads * head_dim
125+
# length : attention sequence length
126+
# temb_in : dense.shape[0] of resnet dense before conv
127+
# out_c : dense.shape[1] of resnet dense before conv
128+
# out_channels : conv.shape[-1] activation
129+
# keep_1 : conv.shape[0] weight
130+
# keep_2 : conv.shape[1] weight
131+
# conv_in : conv.shape[2] weight
132+
# conv_out : conv.shape[-1] weight
133+
logical_axis_rules: [
134+
['batch', 'data'],
135+
['activation_batch', ['data','fsdp']],
136+
['activation_heads', 'tensor'],
137+
['activation_kv', 'tensor'],
138+
['mlp','tensor'],
139+
['embed','fsdp'],
140+
['heads', 'tensor'],
141+
['conv_batch', ['data','fsdp']],
142+
['out_channels', 'tensor'],
143+
['conv_out', 'fsdp'],
144+
]
145+
data_sharding: [['data', 'fsdp', 'tensor']]
146+
147+
# One axis for each parallelism type may hold a placeholder (-1)
148+
# value to auto-shard based on available slices and devices.
149+
# By default, product of the DCN axes should equal number of slices
150+
# and product of the ICI axes should equal number of devices per slice.
151+
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
152+
dcn_fsdp_parallelism: -1
153+
dcn_tensor_parallelism: 1
154+
ici_data_parallelism: -1
155+
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
156+
ici_tensor_parallelism: 1
157+
158+
# Dataset
159+
# Replace with dataset path or train_data_dir. One has to be set.
160+
dataset_name: 'diffusers/pokemon-gpt4-captions'
161+
train_split: 'train'
162+
dataset_type: 'tf'
163+
cache_latents_text_encoder_outputs: True
164+
# cache_latents_text_encoder_outputs only apply to dataset_type="tf",
165+
# only apply to small dataset that fits in memory
166+
# prepare image latents and text encoder outputs
167+
# Reduce memory consumption and reduce step time during training
168+
# transformed dataset is saved at dataset_save_location
169+
dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
170+
train_data_dir: ''
171+
dataset_config_name: ''
172+
jax_cache_dir: ''
173+
hf_data_dir: ''
174+
hf_train_files: ''
175+
hf_access_token: ''
176+
image_column: 'image'
177+
caption_column: 'text'
178+
resolution: 1024
179+
center_crop: False
180+
random_flip: False
181+
# If cache_latents_text_encoder_outputs is True
182+
# the num_proc is set to 1
183+
tokenize_captions_num_proc: 4
184+
transform_images_num_proc: 4
185+
reuse_example_batch: False
186+
enable_data_shuffling: True
187+
188+
# checkpoint every number of samples, -1 means don't checkpoint.
189+
checkpoint_every: -1
190+
# enables one replica to read the ckpt then broadcast to the rest
191+
enable_single_replica_ckpt_restoring: False
192+
193+
# Training loop
194+
learning_rate: 4.e-7
195+
scale_lr: False
196+
max_train_samples: -1
197+
# max_train_steps takes priority over num_train_epochs.
198+
max_train_steps: 200
199+
num_train_epochs: 1
200+
seed: 0
201+
output_dir: 'sdxl-model-finetuned'
202+
per_device_batch_size: 1
203+
204+
warmup_steps_fraction: 0.0
205+
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
206+
207+
# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
208+
# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
209+
210+
# AdamW optimizer parameters
211+
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
212+
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
213+
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
214+
adam_weight_decay: 1.e-2 # AdamW Weight decay
215+
max_grad_norm: 1.0
216+
217+
enable_profiler: False
218+
# Skip first n steps for profiling, to omit things like compilation and to give
219+
# the iteration time a chance to stabilize.
220+
skip_first_n_steps_for_profiler: 5
221+
profiler_steps: 10
222+
223+
# Generation parameters
224+
prompt: "A magical castle in the middle of a forest, artistic drawing"
225+
prompt_2: "A magical castle in the middle of a forest, artistic drawing"
226+
negative_prompt: "purple, red"
227+
do_classifier_free_guidance: True
228+
guidance_scale: 3.5
229+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
230+
guidance_rescale: 0.0
231+
num_inference_steps: 50
232+
233+
# SDXL Lightning parameters
234+
lightning_from_pt: True
235+
# Empty or "ByteDance/SDXL-Lightning" to enable lightning.
236+
lightning_repo: ""
237+
# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning.
238+
lightning_ckpt: ""
239+
240+
# LoRA parameters
241+
# Values are lists to support multiple LoRA loading during inference in the future.
242+
lora_config: {
243+
lora_model_name_or_path: [],
244+
weight_name: [],
245+
adapter_name: [],
246+
scale: [],
247+
from_pt: []
248+
}
249+
# Ex with values:
250+
# lora_config : {
251+
# lora_model_name_or_path: ["ByteDance/Hyper-SD"],
252+
# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"],
253+
# adapter_name: ["hyper-sdxl"],
254+
# scale: [0.7],
255+
# from_pt: [True]
256+
# }
257+
258+
enable_mllog: False
259+
260+
#controlnet
261+
controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
262+
controlnet_from_pt: True
263+
controlnet_conditioning_scale: 0.5
264+
controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
265+
quantization: ''
266+
# Shard the range finding operation for quantization. By default this is set to number of slices.
267+
quantization_local_shard_count: -1
268+
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
269+

src/maxdiffusion/configuration_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,8 @@ def extract_init_dict(cls, config_dict, **kwargs):
464464
# remove flax internal keys
465465
if hasattr(cls, "_flax_internal_args"):
466466
for arg in cls._flax_internal_args:
467-
expected_keys.remove(arg)
467+
if arg in expected_keys:
468+
expected_keys.remove(arg)
468469

469470
# 2. Remove attributes that cannot be expected from expected config attributes
470471
# remove keys to be ignored

src/maxdiffusion/image_processor.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,53 @@
3636
]
3737

3838

39+
def is_valid_image(image) -> bool:
40+
r"""
41+
Checks if the input is a valid image.
42+
43+
A valid image can be:
44+
- A `PIL.Image.Image`.
45+
- A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image).
46+
47+
Args:
48+
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
49+
The image to validate. It can be a PIL image, a NumPy array, or a torch tensor.
50+
51+
Returns:
52+
`bool`:
53+
`True` if the input is a valid image, `False` otherwise.
54+
"""
55+
return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
56+
57+
58+
def is_valid_image_imagelist(images):
59+
r"""
60+
Checks if the input is a valid image or list of images.
61+
62+
The input can be one of the following formats:
63+
- A 4D tensor or numpy array (batch of images).
64+
- A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or
65+
`torch.Tensor`.
66+
- A list of valid images.
67+
68+
Args:
69+
images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
70+
The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
71+
images.
72+
73+
Returns:
74+
`bool`:
75+
`True` if the input is valid, `False` otherwise.
76+
"""
77+
if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
78+
return True
79+
elif is_valid_image(images):
80+
return True
81+
elif isinstance(images, list):
82+
return all(is_valid_image(image) for image in images)
83+
return False
84+
85+
3986
class VaeImageProcessor(ConfigMixin):
4087
"""
4188
Image processor for VAE.

0 commit comments

Comments
 (0)