support Wan2.2-Fun-5B-Control[-Camera] inference #930
@@ -6,15 +6,15 @@
 from typing_extensions import Literal

 class SimpleAdapter(nn.Module):
-    def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1):
+    def __init__(self, in_dim, out_dim, kernel_size, stride, downscale_factor=8, num_residual_blocks=1):
         super(SimpleAdapter, self).__init__()

         # Pixel Unshuffle: reduce spatial dimensions by a factor of 8
-        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
+        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=downscale_factor)

         # Convolution: reduce spatial dimensions by a factor
         # of 2 (without overlap)
-        self.conv = nn.Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0)
+        self.conv = nn.Conv2d(in_dim * downscale_factor * downscale_factor, out_dim, kernel_size=kernel_size, stride=stride, padding=0)

         # Residual blocks for feature extraction
         self.residual_blocks = nn.Sequential(
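Side note on the change above (an illustration, not part of the diff): `nn.PixelUnshuffle(r)` rearranges each `r x r` spatial block into channels, mapping `(C, H, W)` to `(C*r*r, H/r, W/r)`, which is why the hard-coded `in_dim * 64` (valid only for `r = 8`) becomes `in_dim * downscale_factor * downscale_factor`. A minimal shape check with stock PyTorch; the concrete sizes and the `kernel_size=2, stride=2` choice are assumptions for the demo:

```python
import torch
import torch.nn as nn

# PixelUnshuffle(r) maps (C, H, W) -> (C*r*r, H/r, W/r), so the conv that follows
# must take in_dim * r * r input channels. All sizes below are illustrative.
in_dim, downscale_factor = 16, 8
x = torch.randn(1, in_dim, 256, 256)

unshuffle = nn.PixelUnshuffle(downscale_factor=downscale_factor)
conv = nn.Conv2d(in_dim * downscale_factor * downscale_factor, 512, kernel_size=2, stride=2, padding=0)

h = unshuffle(x)
print(h.shape)        # torch.Size([1, 1024, 32, 32])
print(conv(h).shape)  # torch.Size([1, 512, 16, 16])
```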
@@ -600,7 +600,7 @@ def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, he
         pipe.load_models_to_device(self.onload_model_names)
         image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
         clip_context = pipe.image_encoder.encode_image([image])
-        msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
+        msk = torch.ones(1, num_frames, height//(pipe.height_division_factor//2), width//(pipe.width_division_factor//2), device=pipe.device)
         msk[:, 1:] = 0
         if end_image is not None:
             end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)

@@ -612,7 +612,7 @@ def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, he
         vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)

         msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
-        msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
+        msk = msk.view(1, msk.shape[1] // 4, 4, height//(pipe.height_division_factor//2), width//(pipe.width_division_factor//2))
         msk = msk.transpose(1, 2)[0]

         y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
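Side note (not part of the diff): the reason `height//8` can no longer be hard-coded is that the mask has to live on the VAE's latent grid, and the Wan2.2 5B VAE compresses space more strongly than the 8x VAE used by earlier Wan models. A minimal sketch of the shape arithmetic, assuming the pipeline's `height_division_factor`/`width_division_factor` are twice the VAE's spatial compression ratio (16 for an 8x VAE, 32 for a 16x VAE); these factor values are assumptions for illustration:

```python
def latent_mask_shape(num_frames, height, width, height_division_factor, width_division_factor):
    # Spatial size of the mask: the VAE latent grid rather than a hard-coded height // 8.
    h_lat = height // (height_division_factor // 2)
    w_lat = width // (width_division_factor // 2)
    # The first frame is repeated 4 times and frames are then grouped in fours,
    # matching the VAE's temporal compression of 4.
    t_lat = (num_frames - 1) // 4 + 1
    return (4, t_lat, h_lat, w_lat)

# 8x VAE (assumed division factor 16): 81 frames at 480x832 -> mask (4, 21, 60, 104)
print(latent_mask_shape(81, 480, 832, 16, 16))
# 16x VAE of the 5B model (assumed division factor 32): same video -> mask (4, 21, 30, 52)
print(latent_mask_shape(81, 480, 832, 32, 32))
```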
@@ -659,7 +659,7 @@ def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, he
             return {}
         pipe.load_models_to_device(self.onload_model_names)
         image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
-        msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
+        msk = torch.ones(1, num_frames, height//(pipe.height_division_factor//2), width//(pipe.width_division_factor//2), device=pipe.device)
         msk[:, 1:] = 0
         if end_image is not None:
             end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)

@@ -669,7 +669,7 @@ def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, he
         vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)

         msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
-        msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
+        msk = msk.view(1, msk.shape[1] // 4, 4, height//(pipe.height_division_factor//2), width//(pipe.width_division_factor//2))
         msk = msk.transpose(1, 2)[0]

         y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
@@ -719,9 +719,11 @@ def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, wid
         y_dim = pipe.dit.in_dim-control_latents.shape[1]-latents.shape[1]
         if clip_feature is None or y is None:
             clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
-            y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
+            height_division_factor, width_division_factor = pipe.height_division_factor // 2, pipe.width_division_factor // 2
+            y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//height_division_factor, width//width_division_factor), dtype=pipe.torch_dtype, device=pipe.device)
         else:
             y = y[:, -y_dim:]

         y = torch.concat([control_latents, y], dim=1)
         return {"clip_feature": clip_feature, "y": y}
@@ -787,10 +789,10 @@ def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_cont
         vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
         y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
         y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
-        msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
+        msk = torch.ones(1, num_frames, height//(pipe.height_division_factor//2), width//(pipe.width_division_factor//2), device=pipe.device)
         msk[:, 1:] = 0
         msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
-        msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
+        msk = msk.view(1, msk.shape[1] // 4, 4, height//(pipe.height_division_factor//2), width//(pipe.width_division_factor//2))
         msk = msk.transpose(1, 2)[0]
         y = torch.cat([msk,y])
         y = y.unsqueeze(0)
@@ -0,0 +1,41 @@ (new file)
import torch
from diffsynth import save_video,VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from PIL import Image
from modelscope import dataset_snapshot_download

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.2-Fun-5B-Control-Camera", origin_file_pattern="diffusion_pytorch_model.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-5B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-5B-Control-Camera", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern=f"data/examples/wan/input_image.jpg"
)
input_image = Image.open("data/examples/wan/input_image.jpg")

video = pipe(
    prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    seed=0, tiled=True,
    input_image=input_image,
    camera_control_direction="Left", camera_control_speed=0.01,
)
save_video(video, "video_left.mp4", fps=15, quality=5)

video = pipe(
    prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    seed=0, tiled=True,
    input_image=input_image,
    camera_control_direction="Up", camera_control_speed=0.01,
)
save_video(video, "video_up.mp4", fps=15, quality=5)
Review comment on lines +25 to +41:

The two calls to `pipe(...)` are identical except for `camera_control_direction`; consider factoring the shared arguments out and looping over the directions, for example:

```python
common_args = {
    "prompt": "一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
    "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    "seed": 0, "tiled": True,
    "input_image": input_image,
    "camera_control_speed": 0.01,
}
for direction in ["Left", "Up"]:
    video = pipe(
        **common_args,
        camera_control_direction=direction,
    )
    save_video(video, f"video_{direction.lower()}.mp4", fps=15, quality=5)
```
@@ -0,0 +1,34 @@ (new file)
import torch
from diffsynth import save_video,VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from PIL import Image
from modelscope import dataset_snapshot_download

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.2-Fun-5B-Control", origin_file_pattern="diffusion_pytorch_model.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-5B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-5B-Control", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"]
)

# Control video
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832))
video = pipe(
    prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    control_video=control_video, reference_image=reference_image,
    height=832, width=576, num_frames=49,
    seed=1, tiled=True
)
save_video(video, "video.mp4", fps=15, quality=5)
Review comment on the final `save_video(...)` line: there is trailing whitespace on this line, which should be removed to maintain code style consistency.