Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions comfy/ldm/flux/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,9 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
pe=pe, mask=attn_mask, transformer_options=transformer_options)

txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

# calculate the img bloks
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)

# calculate the txt bloks
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
Expand Down
17 changes: 7 additions & 10 deletions comfy/ldm/flux/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,10 @@

from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management

import comfy.ops as ops

def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
q_shape = q.shape
k_shape = k.shape

if pe is not None:
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)

q, k = apply_rope(q, k, pe)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
return x
Expand All @@ -35,7 +27,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.to(dtype=torch.float32, device=pos.device)

@torch.cuda.nvtx.range("apply_rope1")
def apply_rope1(x: Tensor, freqs_cis: Tensor):
if ops._CK_AVAILABLE:
return ops.ck.apply_rope1(x, freqs_cis)
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)

x_out = freqs_cis[..., 0] * x_[..., 0]
Expand All @@ -44,4 +39,6 @@ def apply_rope1(x: Tensor, freqs_cis: Tensor):
return x_out.reshape(*x.shape).type_as(x)

def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
if ops._CK_AVAILABLE:
return ops.ck.apply_rope(xq, xk, freqs_cis)
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
95 changes: 39 additions & 56 deletions comfy/ldm/lightricks/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
from einops import rearrange
import math
from typing import Dict, Optional, Tuple

from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords

from comfy.ldm.flux.math import apply_rope1

def get_timestep_embedding(
timesteps: torch.Tensor,
Expand Down Expand Up @@ -238,20 +237,6 @@ def forward(self, x):
return self.net(x)


def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
cos_freqs = freqs_cis[0]
sin_freqs = freqs_cis[1]

t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
t1, t2 = t_dup.unbind(dim=-1)
t_dup = torch.stack((-t2, t1), dim=-1)
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")

out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs

return out


class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
Expand Down Expand Up @@ -281,8 +266,8 @@ def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
k = self.k_norm(k)

if pe is not None:
q = apply_rotary_emb(q, pe)
k = apply_rotary_emb(k, pe)
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)

if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
Expand All @@ -306,12 +291,17 @@ def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None,
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)

x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
norm_x = comfy.ldm.common_dit.rms_norm(x)
attn1_input = torch.addcmul(norm_x, norm_x, scale_msa).add_(shift_msa)
attn1_result = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
x.addcmul_(attn1_result, gate_msa)

x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)

y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp
norm_x = comfy.ldm.common_dit.rms_norm(x)
y = torch.addcmul(norm_x, norm_x, scale_mlp).add_(shift_mlp)
ff_result = self.ff(y)
x.addcmul_(ff_result, gate_mlp)

return x

Expand All @@ -325,43 +315,36 @@ def get_fractional_positions(indices_grid, max_pos):
)
return fractional_positions


def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
dtype = torch.float32 #self.dtype

dtype = torch.float32
device = indices_grid.device

# Get fractional positions and compute frequency indices
fractional_positions = get_fractional_positions(indices_grid, max_pos)

start = 1
end = theta
device = fractional_positions.device

indices = theta ** (
torch.linspace(
math.log(start, theta),
math.log(end, theta),
dim // 6,
device=device,
dtype=dtype,
)
)
indices = indices.to(dtype=dtype)

indices = indices * math.pi / 2

freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)

cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2

# Compute frequencies and apply cos/sin
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)

# Pad if dim is not divisible by 6
if dim % 6 != 0:
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
padding_size = dim % 6
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)

# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2]
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2]

# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
freqs_cis = torch.stack([
torch.stack([cos_vals, -sin_vals], dim=-1),
torch.stack([sin_vals, cos_vals], dim=-1)
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]

return freqs_cis.to(out_dtype)


class LTXVModel(torch.nn.Module):
Expand Down Expand Up @@ -501,7 +484,7 @@ def block_wrap(args):
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
# Modulation
x = x * (1 + scale) + shift
x = torch.addcmul(x, x, scale).add_(shift)
x = self.proj_out(x)

x = self.patchifier.unpatchify(
Expand Down
40 changes: 22 additions & 18 deletions comfy/ldm/qwen_image/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import comfy.ldm.common_dit
import comfy.patcher_extension

from comfy.ldm.flux.math import apply_rope1

class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
super().__init__()
Expand Down Expand Up @@ -56,7 +58,7 @@ def apply_rotary_emb(x, freqs_cis):

t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x.shape)
return t_out.reshape(*x.shape).to(x.dtype)


class QwenTimestepProjEmbeddings(nn.Module):
Expand Down Expand Up @@ -134,33 +136,34 @@ def forward(
image_rotary_emb: Optional[torch.Tensor] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.shape[0]
seq_img = hidden_states.shape[1]
seq_txt = encoder_hidden_states.shape[1]

img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
# Project and reshape to BHND format (batch, heads, seq, dim)
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)

txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)

img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)

joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)

joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
joint_query = torch.cat([txt_query, img_query], dim=2)
joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2)

joint_query = joint_query.flatten(start_dim=2)
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)

joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attention_mask, transformer_options=transformer_options,
skip_reshape=True)

txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
Expand Down Expand Up @@ -413,7 +416,8 @@ def _forward(
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
image_rotary_emb = self.pe_embedder(ids).to(torch.float32).contiguous()
# image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
del ids, txt_ids, img_ids

hidden_states = self.img_in(hidden_states)
Expand Down
1 change: 1 addition & 0 deletions comfy/ldm/wan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def forward(
# assert e[0].dtype == torch.float32

# self-attention
x = x.contiguous() # otherwise implicit in LayerNorm
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, transformer_options=transformer_options)
Expand Down
11 changes: 10 additions & 1 deletion comfy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
import comfy.float
import comfy.rmsnorm
import contextlib
try:
import comfy_kitchen as ck
ck.disable_backend("cutile")
_CK_AVAILABLE = True
except:
_CK_AVAILABLE = False

def run_every_op():
comfy.model_management.throw_exception_if_processing_interrupted()
Expand Down Expand Up @@ -352,7 +358,10 @@ def fp8_linear(self, input):
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
else:
scale_input = scale_input.to(input.device)
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
if _CK_AVAILABLE:
input = ck.quantize_per_tensor_fp8(input, scale_input).reshape(-1, input_shape[2])
else:
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()

if bias is not None:
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
Expand Down
Loading