-
Notifications
You must be signed in to change notification settings - Fork 46
Expand file tree
/
Copy pathcosmos_t2v_inference.py
More file actions
161 lines (134 loc) · 8.55 KB
/
cosmos_t2v_inference.py
File metadata and controls
161 lines (134 loc) · 8.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import argparse
import json
import os
import math
from copy import deepcopy
import torch
from diffusers import CosmosTextToWorldPipeline
from diffusers.utils import export_to_video
from termcolor import colored
from dataloader import load_prompt_or_image
from svg.models.cosmos.inference import replace_cosmos_attention
from svg.utils.seed import seed_everything
from svg.timer import print_operator_log_data
from svg.logger import logger
def build_parser():
    """Build the command-line parser for the Cosmos text-to-video script.

    Returns:
        argparse.ArgumentParser: parser exposing the generation options, the
        attention-pattern selector and the SVG / SAP specific knobs.
    """
    parser = argparse.ArgumentParser(description="Generate video from text prompt using Cosmos Text2World")
    parser.add_argument("--model_id", type=str, default="nvidia/Cosmos-1.0-Diffusion-14B-Text2World", help="Model ID to use for generation")
    parser.add_argument("--prompt", type=str, default=None, help="Text prompt for video generation")
    parser.add_argument("--negative_prompt", type=str, default=None, help="Negative text prompt to avoid certain features")
    parser.add_argument("--prompt_source", type=str, default="prompt", choices=["prompt", "T2V_Wan_VBench", "T2V_Xingyang_VBench"], help="Source of the prompt")
    parser.add_argument("--prompt_idx", type=int, default=0, help="Index of the prompt")
    parser.add_argument("--height", type=int, default=704, help="Height of the generated video")
    parser.add_argument("--width", type=int, default=1280, help="Width of the generated video")
    parser.add_argument("--num_frames", type=int, default=121, help="Number of frames in the generated video")
    parser.add_argument("--num_inference_steps", type=int, default=35, help="Number of denoising steps in the generated video")
    parser.add_argument("--output_file", type=str, default="output.mp4", help="Output video file name")
    parser.add_argument("--logging_file", type=str, default=None, help="Path to the logging file.")
    parser.add_argument("--seed", type=int, default=0, help="Random seed for generation")
    parser.add_argument("--skip_existing", action="store_true", help="Skip generating existing output files")

    parser.add_argument("--pattern", type=str, default="dense", choices=["SVG", "dense", "SAP"])
    parser.add_argument("--first_layers_fp", type=float, default=0.025, help="Only works for best config. Leave the 0, 1, 2, 40, 41 layers in FP")
    # argparse %-formats help strings, so a literal percent sign must be
    # written as "%%"; a bare "%" makes --help raise ValueError.
    parser.add_argument("--first_times_fp", type=float, default=0.075, help="Only works for best config. Leave the first 10%% timestep in FP")

    # SVG related
    parser.add_argument("--num_sampled_rows", type=int, default=64, help="The number of sampled rows")
    parser.add_argument("--sample_mse_max_row", type=int, default=10000, help="The maximum number of rows in attention mask. Prevent OOM.")
    parser.add_argument("--sparsity", type=float, default=0.25, help="The sparsity of the striped attention pattern. Accepts one or two float values.")

    # SAP related
    parser.add_argument("--num_q_centroids", "--qc", type=int, default=50, help="Number of query centroids for KMEANS_BLOCK.")
    parser.add_argument("--num_k_centroids", "--kc", type=int, default=200, help="Number of key centroids for KMEANS_BLOCK.")
    parser.add_argument("--top_p_kmeans", type=float, default=0.9, help="Top-p threshold for block selection in KMEANS_BLOCK.")
    parser.add_argument("--min_kc_ratio", type=float, default=0, help="At least this proportion of key blocks to keep per query block in KMEANS_BLOCK.")
    parser.add_argument("--kmeans_iter_init", type=int, default=0, help="Number of KMeans iterations for initialization in KMEANS_BLOCK.")
    parser.add_argument("--kmeans_iter_step", type=int, default=0, help="Number of KMeans iterations for other diffusion steps in KMEANS_BLOCK.")
    return parser


def resolve_fp_warmup(first_times_fp, first_layers_fp, num_inference_steps, num_layers, timesteps):
    """Translate the warmup fractions into concrete counts and a timestep threshold.

    Args:
        first_times_fp: fraction of the denoising steps to warm up.
        first_layers_fp: fraction of the transformer layers to warm up.
        num_inference_steps: total number of denoising steps.
        num_layers: number of transformer layers.
        timesteps: indexable sequence of scheduler timesteps.

    Returns:
        tuple: ``(num_fp_timesteps, num_fp_layers, timestep_threshold)`` where
        both counts are the floored products of fraction and total, and the
        threshold is ``timesteps[num_fp_timesteps - 1] - 0.01`` when at least
        one warmup step exists, otherwise ``1001``.
    """
    num_fp_timesteps = math.floor(first_times_fp * num_inference_steps)
    num_fp_layers = math.floor(first_layers_fp * num_layers)
    if num_fp_timesteps > 0:
        # Sits just below the last warmup timestep.
        timestep_threshold = timesteps[num_fp_timesteps - 1] - 0.01
    else:
        timestep_threshold = 1001  # 1000 is the first timestep
    return num_fp_timesteps, num_fp_layers, timestep_threshold


def main():
    """Parse the CLI, load the Cosmos pipeline, patch attention and export a video."""
    args = build_parser().parse_args()

    seed_everything(args.seed)
    # In some cases it will raise RuntimeError: cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR
    torch.backends.cuda.preferred_linalg_library(backend="magma")

    if args.skip_existing and os.path.exists(args.output_file):
        logger.info(f"Output file {args.output_file} already exists. Skipping generation.")
        # SystemExit is always available; the exit() helper is injected by
        # the site module and is intended for interactive sessions only.
        raise SystemExit(0)

    #########################################################
    # Load the model
    #########################################################
    # Available models: nvidia/Cosmos-1.0-Diffusion-14B-Text2World, nvidia/Cosmos-1.0-Diffusion-7B-Text2World
    pipe = CosmosTextToWorldPipeline.from_pretrained(args.model_id, torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    config = pipe.transformer.config

    #########################################################
    # Translate the percentage of warmup of layers and timesteps to the actual layers and timesteps
    #########################################################
    # Work on a copy so the pipeline's own scheduler state is left untouched.
    ref_scheduler = deepcopy(pipe.scheduler)
    ref_scheduler.set_timesteps(args.num_inference_steps)

    num_fp_timesteps, num_fp_layers, args.first_times_fp = resolve_fp_warmup(
        args.first_times_fp,
        args.first_layers_fp,
        args.num_inference_steps,
        config.num_layers,
        ref_scheduler.timesteps,
    )
    args.first_layers_fp = num_fp_layers

    logger.info(f"Warmup of Timesteps: {num_fp_timesteps} / {args.num_inference_steps} || {args.first_times_fp} / 1000 use FP")
    logger.info(f"Warmup of Layers: {num_fp_layers} / {config.num_layers} use FP")

    #########################################################
    # Load the prompt
    #########################################################
    args.prompt, _ = load_prompt_or_image(args.prompt_source, args.prompt_idx, args.prompt, None)
    if args.prompt is None:
        logger.info(colored("Using default prompt", "red"))
        args.prompt = "A cat walks on the grass, realistic"

    if args.negative_prompt is None:
        args.negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."

    print("=" * 20 + " Prompts " + "=" * 20)
    print(f"Prompt: {args.prompt}\n\n" + f"Negative Prompt: {args.negative_prompt}")

    #########################################################
    # Replace the attention & time logger
    #########################################################
    # Pattern-specific keyword arguments; "dense" has no entry, so the
    # pipeline's attention is left unmodified for it.
    pattern_kwargs = {
        "SVG": {
            "num_sampled_rows": args.num_sampled_rows,
            "sample_mse_max_row": args.sample_mse_max_row,
            "sparsity": args.sparsity,
        },
        "SAP": {
            "num_q_centroids": args.num_q_centroids,
            "num_k_centroids": args.num_k_centroids,
            "top_p_kmeans": args.top_p_kmeans,
            "min_kc_ratio": args.min_kc_ratio,
            "logging_file": args.logging_file,
            "kmeans_iter_init": args.kmeans_iter_init,
            "kmeans_iter_step": args.kmeans_iter_step,
        },
    }
    if args.pattern in pattern_kwargs:
        replace_cosmos_attention(
            pipe,
            args.height,
            args.width,
            args.num_frames,
            first_layers_fp=args.first_layers_fp,
            first_times_fp=args.first_times_fp,
            pattern=args.pattern,
            **pattern_kwargs[args.pattern],
        )

    # Print time logger
    for block in pipe.transformer.transformer_blocks:
        block.register_forward_hook(print_operator_log_data)

    #########################################################
    # Generate the video
    #########################################################
    output = pipe(
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        guidance_scale=7.0,
        num_inference_steps=args.num_inference_steps,
    ).frames[0]

    # Create parent directory for output file if it doesn't exist
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    export_to_video(output, args.output_file, fps=30)


if __name__ == "__main__":
    main()