From cad5e9207634cd03a017b915b2915590a73f548f Mon Sep 17 00:00:00 2001 From: Feiteng Date: Fri, 19 Sep 2025 14:57:48 +0000 Subject: [PATCH] support Wan2.2-Fun-5B-Control[-Camera] inference --- diffsynth/configs/model_config.py | 2 + .../models/wan_video_camera_controller.py | 6 +-- diffsynth/models/wan_video_dit.py | 48 ++++++++++++++++++- diffsynth/pipelines/wan_video_new.py | 16 ++++--- .../Wan2.2-Fun-5B-Control-Camera.py | 41 ++++++++++++++++ .../model_inference/Wan2.2-Fun-5B-Control.py | 34 +++++++++++++ 6 files changed, 135 insertions(+), 12 deletions(-) create mode 100644 examples/wanvideo/model_inference/Wan2.2-Fun-5B-Control-Camera.py create mode 100644 examples/wanvideo/model_inference/Wan2.2-Fun-5B-Control.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 43fe84b0..51db5199 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -153,6 +153,8 @@ (None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"), (None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"), (None, "2267d489f0ceb9f21836532952852ee5", ["wan_video_dit"], [WanModel], "civitai"), + (None, "8cf5720f1d99f2d3d9f4d059c99f7e25", ["wan_video_dit"], [WanModel], "civitai"), + (None, "0e2ab7dec4711919374f3d7ffdea90be", ["wan_video_dit"], [WanModel], "civitai"), (None, "47dbeab5e560db3180adf51dc0232fb1", ["wan_video_dit"], [WanModel], "civitai"), (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), diff --git a/diffsynth/models/wan_video_camera_controller.py b/diffsynth/models/wan_video_camera_controller.py index 45a44ee6..cf647809 100644 --- a/diffsynth/models/wan_video_camera_controller.py +++ b/diffsynth/models/wan_video_camera_controller.py @@ -6,15 +6,15 @@ from typing_extensions import Literal class SimpleAdapter(nn.Module): - def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1): + def __init__(self, in_dim, out_dim, kernel_size, stride, downscale_factor=8, num_residual_blocks=1): super(SimpleAdapter, self).__init__() # Pixel Unshuffle: reduce spatial dimensions by a factor of 8 - self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8) + self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=downscale_factor) # Convolution: reduce spatial dimensions by a factor # of 2 (without overlap) - self.conv = nn.Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0) + self.conv = nn.Conv2d(in_dim * downscale_factor * downscale_factor, out_dim, kernel_size=kernel_size, stride=stride, padding=0) # Residual blocks for feature extraction self.residual_blocks = nn.Sequential( diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 1a54728f..52ef2e17 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -287,6 +287,7 @@ def __init__( has_ref_conv: bool = False, add_control_adapter: bool = False, in_dim_control_adapter: int = 24, + downscale_factor_control_adapter: int = 8, seperated_timestep: bool = False, require_vae_embedding: bool = True, require_clip_embedding: bool = True, @@ -328,11 +329,13 @@ def __init__( if has_image_input: self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280 if has_ref_conv: - self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), 
stride=(2, 2)) + self.ref_conv = nn.Conv2d(out_dim, dim, kernel_size=(2, 2), stride=(2, 2)) self.has_image_pos_emb = has_image_pos_emb self.has_ref_conv = has_ref_conv if add_control_adapter: - self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:]) + self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, + kernel_size=patch_size[1:], stride=patch_size[1:], + downscale_factor=downscale_factor_control_adapter) else: self.control_adapter = None @@ -750,6 +753,47 @@ def from_civitai(self, state_dict): "in_dim_control_adapter": 24, "require_clip_embedding": False, } + elif hash_state_dict_keys(state_dict) == "8cf5720f1d99f2d3d9f4d059c99f7e25": + # Wan2.2-Fun-5B-Control + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 148, + "dim": 3072, + "ffn_dim": 14336, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 48, + "num_heads": 24, + "num_layers": 30, + "eps": 1e-6, + "has_ref_conv": True, + "require_clip_embedding": False, + "seperated_timestep": True, + "require_vae_embedding": True, + "fuse_vae_embedding_in_latents": True, + } + elif hash_state_dict_keys(state_dict) == "0e2ab7dec4711919374f3d7ffdea90be": + # Wan2.2-Fun-5B-Control-Camera + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 100, + "dim": 3072, + "ffn_dim": 14336, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 48, + "num_heads": 24, + "num_layers": 30, + "eps": 1e-6, + "has_ref_conv": False, + "require_clip_embedding": False, + "seperated_timestep": True, + "add_control_adapter": True, + "downscale_factor_control_adapter": 16, + "in_dim_control_adapter": 24, + } else: config = {} return state_dict, config diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 660a38e7..cb404107 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -600,7 +600,7 @@ def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, he pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) clip_context = pipe.image_encoder.encode_image([image]) - msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk = torch.ones(1, num_frames, height//(pipe.height_division_factor//2), width//(pipe.width_division_factor//2), device=pipe.device) msk[:, 1:] = 0 if end_image is not None: end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) @@ -612,7 +612,7 @@ def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, he vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) - msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.view(1, msk.shape[1] // 4, 4, height//(pipe.height_division_factor//2), width//(pipe.width_division_factor//2)) msk = msk.transpose(1, 2)[0] y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] @@ -659,7 +659,7 @@ def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, he return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) - msk = torch.ones(1, num_frames, height//8, 
width//8, device=pipe.device) + msk = torch.ones(1, num_frames, height//(pipe.height_division_factor//2), width//(pipe.width_division_factor//2), device=pipe.device) msk[:, 1:] = 0 if end_image is not None: end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) @@ -669,7 +669,7 @@ def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, he vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) - msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.view(1, msk.shape[1] // 4, 4, height//(pipe.height_division_factor//2), width//(pipe.width_division_factor//2)) msk = msk.transpose(1, 2)[0] y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] @@ -719,9 +719,11 @@ def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, wid y_dim = pipe.dit.in_dim-control_latents.shape[1]-latents.shape[1] if clip_feature is None or y is None: clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device) - y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device) + height_division_factor, width_division_factor = pipe.height_division_factor // 2, pipe.width_division_factor // 2 + y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//height_division_factor, width//width_division_factor), dtype=pipe.torch_dtype, device=pipe.device) else: y = y[:, -y_dim:] + y = torch.concat([control_latents, y], dim=1) return {"clip_feature": clip_feature, "y": y} @@ -787,10 +789,10 @@ def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_cont vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] y = y.to(dtype=pipe.torch_dtype, device=pipe.device) - msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk = torch.ones(1, num_frames, height//(pipe.height_division_factor//2), width//(pipe.width_division_factor//2), device=pipe.device) msk[:, 1:] = 0 msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) - msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.view(1, msk.shape[1] // 4, 4, height//(pipe.height_division_factor//2), width//(pipe.width_division_factor//2)) msk = msk.transpose(1, 2)[0] y = torch.cat([msk,y]) y = y.unsqueeze(0) diff --git a/examples/wanvideo/model_inference/Wan2.2-Fun-5B-Control-Camera.py b/examples/wanvideo/model_inference/Wan2.2-Fun-5B-Control-Camera.py new file mode 100644 index 00000000..a5da998f --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-Fun-5B-Control-Camera.py @@ -0,0 +1,41 @@ +import torch +from diffsynth import save_video,VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from PIL import Image +from modelscope import dataset_snapshot_download + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.2-Fun-5B-Control-Camera", origin_file_pattern="diffusion_pytorch_model.safetensors", 
offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-Fun-5B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-Fun-5B-Control-Camera", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +input_image = Image.open("data/examples/wan/input_image.jpg") + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Left", camera_control_speed=0.01, +) +save_video(video, "video_left.mp4", fps=15, quality=5) + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Up", camera_control_speed=0.01, +) +save_video(video, "video_up.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.2-Fun-5B-Control.py b/examples/wanvideo/model_inference/Wan2.2-Fun-5B-Control.py new file mode 100644 index 00000000..b9a10f17 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-Fun-5B-Control.py @@ -0,0 +1,34 @@ +import torch +from diffsynth import save_video,VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from PIL import Image +from modelscope import dataset_snapshot_download + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.2-Fun-5B-Control", origin_file_pattern="diffusion_pytorch_model.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-Fun-5B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-Fun-5B-Control", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"] +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832)) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, reference_image=reference_image, + height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video.mp4", fps=15, quality=5)