
Commit 0d9afba

Multi res support (#159)
* fixes accidental wrong merge. * precompile generate functions with different dimensions. * iterate over different resolutions and store precompiled functions in dict. * adds a new inference file to show how to precompile for different resolutions. * formatting * remove unused dependencies. * update config with flux names * revert to torch 2.5.1 due to compatibility with torchvision. --------- Co-authored-by: Juan Acevedo <[email protected]>
1 parent d08b8d6 commit 0d9afba
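The multi-resolution pattern this commit describes (compile the generate function once per resolution, then dispatch from a cache) can be sketched minimally in JAX. This is a hedged illustration; the function body and latent shapes below are hypothetical, not the repo's actual inference code:

import jax
import jax.numpy as jnp

def generate(latents):
    # stand-in for the real denoising/generate step
    return latents * 2.0

# Ahead-of-time compile one executable per resolution and cache them in a
# dict keyed by (height, width), so switching resolutions never recompiles.
compiled = {}
for h, w in [(512, 512), (768, 768), (1024, 1024)]:
    example = jnp.zeros((1, 16, h // 8, w // 8), dtype=jnp.bfloat16)
    compiled[(h, w)] = jax.jit(generate).lower(example).compile()

# Later, dispatch by resolution without triggering a recompile:
out = compiled[(1024, 1024)](jnp.ones((1, 16, 128, 128), dtype=jnp.bfloat16))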

File tree

5 files changed: +855 additions, -4 deletions


requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ absl-py
 datasets
 flax>=0.10.2
 optax>=0.2.3
-torch==2.6.0
+torch==2.5.1
 torchvision==0.20.1
 ftfy
 tensorboard>=2.17.0

requirements_with_jax_stable_stack.txt

Lines changed: 1 addition & 1 deletion
@@ -29,6 +29,6 @@ tensorboardx==2.6.2.2
 tensorflow>=2.17.0
 tensorflow-datasets>=4.9.6
 tokenizers==0.21.0
-torch==2.6.0
+torch==2.5.1
 torchvision==0.20.1
 transformers==4.48.1
Lines changed: 271 additions & 0 deletions

@@ -0,0 +1,271 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This sentinel is a reminder to choose a real run name.
run_name: ''

metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
write_metrics: True
gcs_metrics: False
# If true save config to GCS in {base_output_directory}/{run_name}/
save_config_to_gcs: False
log_period: 100

pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev'
clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'

# Flux params
flux_name: "flux-dev"
max_sequence_length: 512
time_shift: True
base_shift: 0.5
max_shift: 1.15
# offloads the t5 encoder after text encoding to save memory.
offload_encoders: True
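The base_shift and max_shift values above feed Flux's resolution-dependent timestep shift. A minimal sketch of the commonly used shifting formula follows; the 256/4096 sequence-length anchors are the public Flux reference defaults, not values taken from this commit:

import math

# Resolution-dependent timestep shift, as in the public Flux reference code.
# Longer latent sequences (larger images) push the sigmas toward 1.0.
def shift_sigmas(sigmas, image_seq_len, base_shift=0.5, max_shift=1.15):
    m = (max_shift - base_shift) / (4096 - 256)
    mu = base_shift + m * (image_seq_len - 256)
    return [math.exp(mu) / (math.exp(mu) + (1 / s - 1)) for s in sigmas]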
unet_checkpoint: ''
revision: 'refs/pr/95'
# This will convert the weights to this dtype.
# When running inference on TPUv5e, use weights_dtype: 'bfloat16'
weights_dtype: 'bfloat16'
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
activations_dtype: 'bfloat16'

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
# fp32 activations and fp32 weights with HIGHEST will provide the best precision
# at the cost of time.
precision: "DEFAULT"

# If False, the state is not jitted and replicate is called instead. This is good for debugging on a single host.
# It must be True for multi-host.
jit_initializers: True

# Set True to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te

# flash_block_sizes: {}
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
flash_block_sizes: {
  "block_q": 1536,
  "block_kv_compute": 1536,
  "block_kv": 1536,
  "block_q_dkv": 1536,
  "block_kv_dkv": 1536,
  "block_kv_dkv_compute": 1536,
  "block_q_dq": 1536,
  "block_kv_dq": 1536
}
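These keys mirror the block-size fields of JAX's TPU splash-attention kernel. A hedged sketch of how such a dict plausibly maps onto the kernel's BlockSizes; the module path is an assumption, and the diff does not show maxdiffusion's actual wiring:

from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

# Hypothetical mapping from the flash_block_sizes config onto the Pallas
# splash-attention BlockSizes; every field is 1536 as in the config above.
flash_block_sizes = {k: 1536 for k in (
    "block_q", "block_kv_compute", "block_kv", "block_q_dkv",
    "block_kv_dkv", "block_kv_dkv_compute", "block_q_dq", "block_kv_dq")}
block_sizes = splash_attention_kernel.BlockSizes(**flash_block_sizes)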
# GroupNorm groups
norm_num_groups: 32

# If train_new_flux, unet weights will be randomly initialized to train the unet from scratch;
# otherwise they will be loaded from pretrained_model_name_or_path.
train_new_flux: False

# Train the text encoder - currently not supported for SDXL.
train_text_encoder: False
text_encoder_learning_rate: 4.25e-6

# https://arxiv.org/pdf/2305.08891.pdf
snr_gamma: -1.0
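snr_gamma: -1.0 disables SNR-based loss weighting; a positive value typically enables min-SNR-style weighting. A hedged illustration of that weighting, not this repo's training loop:

import jax.numpy as jnp

# Min-SNR-style loss weighting: clamp each sampled timestep's
# signal-to-noise ratio at snr_gamma so low-noise steps don't dominate.
def min_snr_weights(alphas_cumprod, timesteps, snr_gamma=5.0):
    snr = alphas_cumprod[timesteps] / (1.0 - alphas_cumprod[timesteps])
    return jnp.minimum(snr, snr_gamma) / snr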
timestep_bias: {
  # a value of "later" will increase the frequency of the model's final training timesteps.
  # none, earlier, later, range
  strategy: "none",
  # multiplier for bias; a value of 2.0 will double the weight of the bias, 0.5 will halve it.
  multiplier: 1.0,
  # when using strategy=range, the beginning (inclusive) timestep to bias.
  begin: 0,
  # when using strategy=range, the final timestep (inclusive) to bias.
  end: 1000,
  # portion of timesteps to bias.
  # 0.5 will bias one half of the timesteps. The value of strategy determines
  # whether the biased portions are in the earlier or later timesteps.
  portion: 0.25
}
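A hedged sketch of how such a bias plausibly becomes a sampling distribution over timesteps; this illustrates the idea, not the repo's exact implementation:

import numpy as np

# Up-weight a chosen portion of the timestep range by `multiplier`,
# then normalize into a categorical distribution.
def timestep_weights(num_timesteps=1000, strategy="later", multiplier=2.0, portion=0.25):
    weights = np.ones(num_timesteps)
    n_biased = int(num_timesteps * portion)
    if strategy == "later":
        weights[-n_biased:] *= multiplier
    elif strategy == "earlier":
        weights[:n_biased] *= multiplier
    return weights / weights.sum()

t = np.random.choice(1000, size=8, p=timestep_weights())  # biased batch of timesteps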
# Override parameters from the checkpoint's scheduler.
diffusion_scheduler_config: {
  _class_name: 'FlaxEulerDiscreteScheduler',
  prediction_type: 'epsilon',
  rescale_zero_terminal_snr: False,
  timestep_spacing: 'trailing'
}
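For reference, a hedged sketch of constructing a Flax scheduler in diffusers with overrides like those above; the repo id and subfolder are illustrative assumptions, and a FLUX checkpoint may ship a different scheduler class:

from diffusers import FlaxEulerDiscreteScheduler

# Flax schedulers in diffusers return (scheduler, state); kwargs override
# values from the checkpoint's scheduler config.
scheduler, scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="scheduler",
    prediction_type="epsilon",
    timestep_spacing="trailing",
)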
# Output directory
# Create a GCS bucket, e.g. my-maxtext-outputs, and set this to "gs://my-maxtext-outputs/"
base_output_directory: ""

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']

# batch : batch dimension of data and activations
# hidden :
# embed : attention qkv dense layer hidden dim named as embed
# heads : attention head dim = num_heads * head_dim
# length : attention sequence length
# temb_in : dense.shape[0] of resnet dense before conv
# out_c : dense.shape[1] of resnet dense before conv
# out_channels : conv.shape[-1] activation
# keep_1 : conv.shape[0] weight
# keep_2 : conv.shape[1] weight
# conv_in : conv.shape[2] weight
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
  ['batch', 'data'],
  ['activation_batch', ['data', 'fsdp']],
  ['activation_heads', 'tensor'],
  ['activation_kv', 'tensor'],
  # ['embed', 'fsdp'],
  ['mlp', ['fsdp', 'tensor']],
  ['heads', 'tensor'],
  ['conv_batch', ['data', 'fsdp']],
  ['out_channels', 'tensor'],
  ['conv_out', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
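For context, a minimal sketch of how mesh_axes, logical_axis_rules, and data_sharding typically come together in JAX/Flax, using standard jax.sharding and flax.linen APIs; the repo's config plumbing is not shown in this diff:

import jax
import numpy as np
import flax.linen as nn
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a device mesh over the configured axis names (axis sizes illustrative).
devices = np.array(jax.devices()).reshape(-1, 1, 1)  # (data, fsdp, tensor)
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

# data_sharding: [['data', 'fsdp', 'tensor']] -> batch dim split over all axes.
data_sharding = NamedSharding(mesh, PartitionSpec(("data", "fsdp", "tensor")))

# Flax resolves logical axis names to mesh axes through rules like the above.
rules = (("batch", "data"), ("heads", "tensor"), ("conv_out", "fsdp"))
with mesh, nn.logical_axis_rules(rules):
    spec = nn.logical_to_mesh_axes(("batch", "heads"))  # PartitionSpec('data', 'tensor')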
# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_tensor_parallelism: 1
ici_data_parallelism: -1
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

# Dataset
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
# cache_latents_text_encoder_outputs only applies to dataset_type="tf" and
# only to small datasets that fit in memory. It prepares image latents and
# text encoder outputs ahead of time, reducing memory consumption and step
# time during training. The transformed dataset is saved at dataset_save_location.
dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
train_data_dir: ''
dataset_config_name: ''
jax_cache_dir: ''
hf_data_dir: ''
hf_train_files: ''
hf_access_token: ''
image_column: 'image'
caption_column: 'text'
resolution: 1024
center_crop: False
random_flip: False
# If cache_latents_text_encoder_outputs is True,
# num_proc is set to 1.
tokenize_captions_num_proc: 4
transform_images_num_proc: 4
reuse_example_batch: False
enable_data_shuffling: True

# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False

# Training loop
learning_rate: 4.e-7
scale_lr: False
max_train_samples: -1
# max_train_steps takes priority over num_train_epochs.
max_train_steps: 200
num_train_epochs: 1
seed: 0
output_dir: 'sdxl-model-finetuned'
per_device_batch_size: 1

warmup_steps_fraction: 0.0
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
# However, you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case training ends
# before the rate drops fully to zero, or a shorter schedule, where the unspecified steps have a learning rate of 0.

# AdamW optimizer parameters
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
adam_eps: 1.e-8 # A small constant applied to the denominator outside of the square root.
adam_weight_decay: 1.e-2 # AdamW weight decay
max_grad_norm: 1.0
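A hedged sketch of how these optimizer and schedule knobs map onto standard optax calls; the linear decay is illustrative, and the repo's actual optimizer construction is not part of this diff:

import optax

max_train_steps = 200
warmup_steps = int(0.0 * max_train_steps)  # warmup_steps_fraction
schedule_steps = max_train_steps           # learning_rate_schedule_steps == -1 -> use step count

# Linear warmup, then linear decay to zero over the schedule length.
schedule = optax.join_schedules(
    [optax.linear_schedule(0.0, 4e-7, warmup_steps),
     optax.linear_schedule(4e-7, 0.0, schedule_steps - warmup_steps)],
    boundaries=[warmup_steps])

# AdamW with the config's betas, eps, and weight decay, plus global-norm clipping.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),  # max_grad_norm
    optax.adamw(schedule, b1=0.9, b2=0.999, eps=1e-8, weight_decay=1e-2))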
enable_profiler: False
# Skip the first n steps for profiling, to omit things like compilation and to give
# the iteration time a chance to stabilize.
skip_first_n_steps_for_profiler: 5
profiler_steps: 10

# Generation parameters
prompt: "A magical castle in the middle of a forest, artistic drawing"
prompt_2: "A magical castle in the middle of a forest, artistic drawing"
negative_prompt: "purple, red"
do_classifier_free_guidance: True
guidance_scale: 3.5
# Based on Section 3.4 of https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 50
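When set above zero, guidance_rescale applies the Section 3.4 rescaling from the paper cited above. A hedged sketch of classifier-free guidance with that rescale; this is an illustration, not the repo's pipeline code:

import jax.numpy as jnp

# CFG with the Sec. 3.4 rescale from arXiv 2305.08891: pull the guided
# prediction's std back toward the conditional prediction's std.
def cfg_with_rescale(noise_uncond, noise_cond, guidance_scale=3.5, guidance_rescale=0.0):
    noise = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
    if guidance_rescale > 0.0:
        axes = tuple(range(1, noise.ndim))
        std_cond = jnp.std(noise_cond, axis=axes, keepdims=True)
        std_cfg = jnp.std(noise, axis=axes, keepdims=True)
        rescaled = noise * (std_cond / std_cfg)
        noise = guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise
    return noise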
# SDXL Lightning parameters
lightning_from_pt: True
# Empty or "ByteDance/SDXL-Lightning" to enable lightning.
lightning_repo: ""
# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning.
lightning_ckpt: ""

# LoRA parameters
# Values are lists to support multiple LoRA loading during inference in the future.
lora_config: {
  lora_model_name_or_path: [],
  weight_name: [],
  adapter_name: [],
  scale: [],
  from_pt: []
}
# Example with values:
# lora_config: {
#   lora_model_name_or_path: ["ByteDance/Hyper-SD"],
#   weight_name: ["Hyper-SDXL-2steps-lora.safetensors"],
#   adapter_name: ["hyper-sdxl"],
#   scale: [0.7],
#   from_pt: [True]
# }

enable_mllog: False

# ControlNet
controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
controlnet_from_pt: True
controlnet_conditioning_scale: 0.5
controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'

quantization: ''
# Shard the range-finding operation for quantization. By default this is set to the number of slices.
quantization_local_shard_count: -1
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

src/maxdiffusion/generate_flux.py

Lines changed: 0 additions & 2 deletions
@@ -226,7 +226,6 @@ def get_t5_prompt_embeds(
 
   prompt = [prompt] if isinstance(prompt, str) else prompt
   batch_size = len(prompt)
-
   text_inputs = tokenizer(
       prompt,
       truncation=True,
@@ -244,7 +243,6 @@ def get_t5_prompt_embeds(
   # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
   prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1))
   prompt_embeds = jnp.reshape(prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1))
-
   return prompt_embeds