Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/maxdiffusion/input_pipeline/_tfds_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,16 @@ def make_tfrecord_iterator(
check out preparation script
maxdiffusion/pedagogical_examples/to_tfrecords.py
"""

# set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
# pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
# Dataset cache in the GitHub runner test doesn't contain all the features since it's shared; use the default tfrecord iterator.

# Checks that the dataset path is valid. In the case of GCS, the existence of the directory is not checked.
is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location)

if (
config.cache_latents_text_encoder_outputs
and os.path.isdir(config.dataset_save_location)
and is_dataset_dir_valid
and "load_tfrecord_cached" in config.get_keys()
and config.load_tfrecord_cached
):
Expand Down
34 changes: 21 additions & 13 deletions src/maxdiffusion/models/wan/wan_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ def rename_for_custom_trasformer(key):
return renamed_pt_key


def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
device = jax.devices(device)[0]
def load_fusionx_transformer(
pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
):
device = jax.local_devices(backend=device)[0]
with jax.default_device(device):
if hf_download:
ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="Wan14BT2VFusioniX_fp16_.safetensors")
Expand Down Expand Up @@ -97,7 +99,7 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di
if flax_key in flax_state_dict:
new_tensor = flax_state_dict[flax_key]
else:
new_tensor = jnp.zeros((40,) + flax_tensor.shape)
new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
flax_tensor = new_tensor.at[block_index].set(flax_tensor)
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
validate_flax_state_dict(eval_shapes, flax_state_dict)
Expand All @@ -107,8 +109,10 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di
return flax_state_dict


def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
device = jax.devices(device)[0]
def load_causvid_transformer(
pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
):
device = jax.local_devices(backend=device)[0]
with jax.default_device(device):
if hf_download:
ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="causal_model.pt")
Expand Down Expand Up @@ -145,7 +149,7 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di
if flax_key in flax_state_dict:
new_tensor = flax_state_dict[flax_key]
else:
new_tensor = jnp.zeros((40,) + flax_tensor.shape)
new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
flax_tensor = new_tensor.at[block_index].set(flax_tensor)
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
validate_flax_state_dict(eval_shapes, flax_state_dict)
Expand All @@ -155,18 +159,22 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di
return flax_state_dict


def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
def load_wan_transformer(
pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
):

if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH:
return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
else:
return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)


def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
device = jax.devices(device)[0]
def load_base_wan_transformer(
pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
):
device = jax.local_devices(backend=device)[0]
subfolder = "transformer"
filename = "diffusion_pytorch_model.safetensors.index.json"
local_files = False
Expand Down Expand Up @@ -237,7 +245,7 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d
if flax_key in flax_state_dict:
new_tensor = flax_state_dict[flax_key]
else:
new_tensor = jnp.zeros((40,) + flax_tensor.shape)
new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
flax_tensor = new_tensor.at[block_index].set(flax_tensor)
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
validate_flax_state_dict(eval_shapes, flax_state_dict)
Expand Down
4 changes: 3 additions & 1 deletion src/maxdiffusion/pipelines/wan/wan_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
# 4. Load pretrained weights and move them to device using the state shardings from (3) above.
# This helps with loading sharded weights directly into the accelerators without first copying them
# all to one device and then distributing them, thus using low HBM memory.
params = load_wan_transformer(config.wan_transformer_pretrained_model_name_or_path, params, "cpu")
params = load_wan_transformer(
config.wan_transformer_pretrained_model_name_or_path, params, "cpu", num_layers=wan_config["num_layers"]
)
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
for path, val in flax.traverse_util.flatten_dict(params).items():
sharding = logical_state_sharding[path].value
Expand Down
Loading