2 changes: 2 additions & 0 deletions diffsynth/configs/model_config.py
@@ -153,6 +153,8 @@
(None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"),
(None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"),
(None, "2267d489f0ceb9f21836532952852ee5", ["wan_video_dit"], [WanModel], "civitai"),
(None, "8cf5720f1d99f2d3d9f4d059c99f7e25", ["wan_video_dit"], [WanModel], "civitai"),
(None, "0e2ab7dec4711919374f3d7ffdea90be", ["wan_video_dit"], [WanModel], "civitai"),
Contributor review comment (medium): There is trailing whitespace on this line which should be removed to maintain code style consistency.

Suggested change
(None, "0e2ab7dec4711919374f3d7ffdea90be", ["wan_video_dit"], [WanModel], "civitai"),
(None, "0e2ab7dec4711919374f3d7ffdea90be", ["wan_video_dit"], [WanModel], "civitai"),

(None, "47dbeab5e560db3180adf51dc0232fb1", ["wan_video_dit"], [WanModel], "civitai"),
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
(None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
6 changes: 3 additions & 3 deletions diffsynth/models/wan_video_camera_controller.py
@@ -6,15 +6,15 @@
from typing_extensions import Literal

class SimpleAdapter(nn.Module):
def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1):
def __init__(self, in_dim, out_dim, kernel_size, stride, downscale_factor=8, num_residual_blocks=1):
super(SimpleAdapter, self).__init__()

# Pixel Unshuffle: reduce spatial dimensions by a factor of 8
Contributor review comment (medium): The comment is now outdated as the downscale_factor is a parameter. Please update the comment to reflect this change for better code clarity.

Suggested change
# Pixel Unshuffle: reduce spatial dimensions by a factor of 8
# Pixel Unshuffle: reduce spatial dimensions by a configurable factor

self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=downscale_factor)

# Convolution: reduce spatial dimensions by a factor
# of 2 (without overlap)
self.conv = nn.Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0)
self.conv = nn.Conv2d(in_dim * downscale_factor * downscale_factor, out_dim, kernel_size=kernel_size, stride=stride, padding=0)

# Residual blocks for feature extraction
self.residual_blocks = nn.Sequential(
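A quick note on the channel arithmetic behind this change: nn.PixelUnshuffle(r) rearranges a (C, H, W) feature map into (C * r * r, H / r, W / r), which is why the convolution's input channels become in_dim * downscale_factor * downscale_factor (the old in_dim * 64 was the special case r = 8). A minimal PyTorch sketch with illustrative shapes only; r = 16 matches the downscale_factor_control_adapter used for the 5B camera model below:

import torch
import torch.nn as nn

r = 16                                   # downscale factor; 8 reproduces the old hard-coded behaviour
x = torch.randn(1, 24, 64, 64)           # (batch, in_dim, H, W); in_dim_control_adapter = 24
unshuffled = nn.PixelUnshuffle(downscale_factor=r)(x)
print(unshuffled.shape)                  # torch.Size([1, 6144, 4, 4]); channels = 24 * r * r
conv = nn.Conv2d(24 * r * r, 3072, kernel_size=(2, 2), stride=(2, 2))   # kernel/stride = patch_size[1:]
print(conv(unshuffled).shape)            # torch.Size([1, 3072, 2, 2])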
48 changes: 46 additions & 2 deletions diffsynth/models/wan_video_dit.py
@@ -287,6 +287,7 @@ def __init__(
has_ref_conv: bool = False,
add_control_adapter: bool = False,
in_dim_control_adapter: int = 24,
downscale_factor_control_adapter: int = 8,
seperated_timestep: bool = False,
require_vae_embedding: bool = True,
require_clip_embedding: bool = True,
@@ -328,11 +329,13 @@ def __init__(
if has_image_input:
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
if has_ref_conv:
self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
self.ref_conv = nn.Conv2d(out_dim, dim, kernel_size=(2, 2), stride=(2, 2))
self.has_image_pos_emb = has_image_pos_emb
self.has_ref_conv = has_ref_conv
if add_control_adapter:
self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim,
kernel_size=patch_size[1:], stride=patch_size[1:],
downscale_factor=downscale_factor_control_adapter)
else:
self.control_adapter = None
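Side note on the ref_conv change above: the reference image is VAE-encoded before it reaches this layer, so tying the input channels to out_dim lets the same code serve VAEs with different latent channel counts (16 in the previous hard-coded version, 48 in the Wan2.2-Fun-5B configs added below). A minimal sketch of the resulting shapes, using dim = 3072 and out_dim = 48 from those configs; the latent spatial size here is illustrative:

import torch
import torch.nn as nn

ref_conv = nn.Conv2d(48, 3072, kernel_size=(2, 2), stride=(2, 2))  # out_dim -> dim, 2x2 patchify
ref_latents = torch.randn(1, 48, 52, 36)                           # VAE latents of a reference image
print(ref_conv(ref_latents).shape)                                  # torch.Size([1, 3072, 26, 18])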

@@ -750,6 +753,47 @@ def from_civitai(self, state_dict):
"in_dim_control_adapter": 24,
"require_clip_embedding": False,
}
elif hash_state_dict_keys(state_dict) == "8cf5720f1d99f2d3d9f4d059c99f7e25":
# Wan2.2-Fun-5B-Control
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
"in_dim": 148,
"dim": 3072,
"ffn_dim": 14336,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 48,
"num_heads": 24,
"num_layers": 30,
"eps": 1e-6,
"has_ref_conv": True,
"require_clip_embedding": False,
"seperated_timestep": True,
"require_vae_embedding": True,
"fuse_vae_embedding_in_latents": True,
}
elif hash_state_dict_keys(state_dict) == "0e2ab7dec4711919374f3d7ffdea90be":
# Wan2.2-Fun-5B-Control-Camera
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
"in_dim": 100,
"dim": 3072,
"ffn_dim": 14336,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 48,
"num_heads": 24,
"num_layers": 30,
"eps": 1e-6,
"has_ref_conv": False,
"require_clip_embedding": False,
"seperated_timestep": True,
"add_control_adapter": True,
"downscale_factor_control_adapter": 16,
"in_dim_control_adapter": 24,
}
else:
config = {}
return state_dict, config
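For context on how these branches are selected: from_civitai dispatches on a fingerprint of the checkpoint's parameter names, so supporting a newly released weight file amounts to adding a hash entry plus an architecture config, which is exactly what this diff does here and in model_config.py. A hedged sketch of that dispatch pattern, assuming the fingerprint is an MD5 over the sorted key names; the actual hash_state_dict_keys in DiffSynth-Studio may compute it differently:

import hashlib

def hash_state_dict_keys_sketch(state_dict) -> str:
    # Hypothetical fingerprint: MD5 over the sorted parameter names.
    joined = ",".join(sorted(state_dict.keys()))
    return hashlib.md5(joined.encode("utf-8")).hexdigest()

# Abridged hash -> config table mirroring the branches above.
KNOWN_CONFIGS = {
    "8cf5720f1d99f2d3d9f4d059c99f7e25": {"dim": 3072, "num_layers": 30, "has_ref_conv": True},         # Wan2.2-Fun-5B-Control
    "0e2ab7dec4711919374f3d7ffdea90be": {"dim": 3072, "num_layers": 30, "add_control_adapter": True},  # Wan2.2-Fun-5B-Control-Camera
}

def resolve_config(state_dict) -> dict:
    return KNOWN_CONFIGS.get(hash_state_dict_keys_sketch(state_dict), {})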
16 changes: 9 additions & 7 deletions diffsynth/pipelines/wan_video_new.py
@@ -600,7 +600,7 @@ def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, he
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
clip_context = pipe.image_encoder.encode_image([image])
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
msk = torch.ones(1, num_frames, height//(pipe.height_division_factor//2), width//(pipe.width_division_factor//2), device=pipe.device)
msk[:, 1:] = 0
if end_image is not None:
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
@@ -612,7 +612,7 @@ def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, he
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)

msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.view(1, msk.shape[1] // 4, 4, height//(pipe.height_division_factor//2), width//(pipe.width_division_factor//2))
msk = msk.transpose(1, 2)[0]

y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
@@ -659,7 +659,7 @@ def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, he
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
msk = torch.ones(1, num_frames, height//(pipe.height_division_factor//2), width//(pipe.width_division_factor//2), device=pipe.device)
msk[:, 1:] = 0
if end_image is not None:
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
Expand All @@ -669,7 +669,7 @@ def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, he
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)

msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.view(1, msk.shape[1] // 4, 4, height//(pipe.height_division_factor//2), width//(pipe.width_division_factor//2))
msk = msk.transpose(1, 2)[0]

y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
@@ -719,9 +719,11 @@ def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, wid
y_dim = pipe.dit.in_dim-control_latents.shape[1]-latents.shape[1]
if clip_feature is None or y is None:
clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
height_division_factor, width_division_factor = pipe.height_division_factor // 2, pipe.width_division_factor // 2
y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//height_division_factor, width//width_division_factor), dtype=pipe.torch_dtype, device=pipe.device)
else:
y = y[:, -y_dim:]

Contributor review comment (medium): This blank line is unnecessary and can be removed to improve code conciseness.

y = torch.concat([control_latents, y], dim=1)
return {"clip_feature": clip_feature, "y": y}

@@ -787,10 +789,10 @@ def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_cont
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
msk = torch.ones(1, num_frames, height//(pipe.height_division_factor//2), width//(pipe.width_division_factor//2), device=pipe.device)
msk[:, 1:] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.view(1, msk.shape[1] // 4, 4, height//(pipe.height_division_factor//2), width//(pipe.width_division_factor//2))
msk = msk.transpose(1, 2)[0]
y = torch.cat([msk,y])
y = y.unsqueeze(0)
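The recurring change in this file replaces the hard-coded // 8 with pipe.height_division_factor // 2 (and the width analogue), so the first-frame mask is built on the actual VAE latent grid instead of assuming an 8x spatial downscale. A minimal sketch of that mask construction, under the assumption that division_factor // 2 equals the VAE's spatial downscale, so a division factor of 16 reproduces the previous // 8 while the higher-compression Wan2.2 5B VAE presumably uses a larger value:

import torch

def build_first_frame_mask(num_frames, height, width, division_factor, device="cpu"):
    s = division_factor // 2                          # assumed VAE spatial downscale
    msk = torch.ones(1, num_frames, height // s, width // s, device=device)
    msk[:, 1:] = 0                                    # only the first frame is known
    # Group frames by 4 to match the latent temporal packing, repeating the
    # first frame so the total frame count is divisible by 4.
    msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
    msk = msk.view(1, msk.shape[1] // 4, 4, height // s, width // s)
    return msk.transpose(1, 2)[0]                     # (4, (num_frames - 1) // 4 + 1, H // s, W // s)

mask = build_first_frame_mask(num_frames=49, height=832, width=576, division_factor=16)
print(mask.shape)                                     # torch.Size([4, 13, 104, 72])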
41 changes: 41 additions & 0 deletions examples/wanvideo/model_inference/Wan2.2-Fun-5B-Control-Camera.py
@@ -0,0 +1,41 @@
import torch
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from PIL import Image
from modelscope import dataset_snapshot_download

pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Wan2.2-Fun-5B-Control-Camera", origin_file_pattern="diffusion_pytorch_model.safetensors", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.2-Fun-5B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.2-Fun-5B-Control-Camera", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()

dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=f"data/examples/wan/input_image.jpg"
Contributor review comment (medium): This f-string does not contain any expressions and can be converted to a regular string.

Suggested change
allow_file_pattern=f"data/examples/wan/input_image.jpg"
allow_file_pattern="data/examples/wan/input_image.jpg"

)
input_image = Image.open("data/examples/wan/input_image.jpg")

video = pipe(
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
seed=0, tiled=True,
input_image=input_image,
camera_control_direction="Left", camera_control_speed=0.01,
)
save_video(video, "video_left.mp4", fps=15, quality=5)

video = pipe(
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
seed=0, tiled=True,
input_image=input_image,
camera_control_direction="Up", camera_control_speed=0.01,
)
save_video(video, "video_up.mp4", fps=15, quality=5)
Contributor review comment on lines +25 to +41 (medium): The two calls to pipe are almost identical, with only camera_control_direction and the output filename changing. This duplicated code can be refactored into a loop to improve readability and maintainability.

common_args = {
    "prompt": "一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
    "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    "seed": 0, "tiled": True,
    "input_image": input_image,
    "camera_control_speed": 0.01,
}

for direction in ["Left", "Up"]:
    video = pipe(
        **common_args,
        camera_control_direction=direction,
    )
    save_video(video, f"video_{direction.lower()}.mp4", fps=15, quality=5)

34 changes: 34 additions & 0 deletions examples/wanvideo/model_inference/Wan2.2-Fun-5B-Control.py
@@ -0,0 +1,34 @@
import torch
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from PIL import Image
from modelscope import dataset_snapshot_download

pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Wan2.2-Fun-5B-Control", origin_file_pattern="diffusion_pytorch_model.safetensors", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.2-Fun-5B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="PAI/Wan2.2-Fun-5B-Control", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()

dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"]
)

# Control video
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832))
video = pipe(
prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
control_video=control_video, reference_image=reference_image,
height=832, width=576, num_frames=49,
seed=1, tiled=True
)
save_video(video, "video.mp4", fps=15, quality=5)