From 59c482fd77c27a65d743c6140da75e0abdd91f41 Mon Sep 17 00:00:00 2001 From: lspindler Date: Thu, 16 Oct 2025 16:07:43 +0200 Subject: [PATCH 01/49] Implement mixed precision operations with a registry design and metadate for quant spec in checkpoint. --- comfy/model_base.py | 10 +- comfy/model_detection.py | 125 +++++++++ comfy/ops.py | 484 ++++++++++++++++++++++++++++++++- comfy/supported_models_base.py | 1 + 4 files changed, 618 insertions(+), 2 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 8274c7dea192..7b4651f8eb1b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -134,7 +134,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: fp8 = model_config.optimizations.get("fp8", False) - operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8) + operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config) else: operations = model_config.custom_operations self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) @@ -326,6 +326,14 @@ def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_ if self.model_config.scaled_fp8 is not None: unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8) + + # Save mixed precision metadata + if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config: + metadata = { + "format_version": "1.0", + "layers": self.model_config.layer_quant_config + } + unet_state_dict["_quantization_metadata"] = metadata unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7677617c0efb..7ce9aaa9af4a 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -6,6 +6,125 @@ import logging import torch + +# ============================================================================== +# Quantization Detection Functions +# ============================================================================== + +def normalize_layer_name(full_key, known_prefixes): + """ + Strip model prefix and parameter suffix from a state dict key. + + Args: + full_key: Full state dict key (e.g., "model.diffusion_model.layer1.weight") + known_prefixes: List of known model prefixes to strip + + Returns: + Normalized layer name (e.g., "layer1") + """ + name = full_key + + # Strip model prefix + for prefix in known_prefixes: + if name.startswith(prefix): + name = name[len(prefix):] + break + + # Remove parameter suffix + for suffix in [".weight", ".bias", ".scale_weight", ".scale_input"]: + if name.endswith(suffix): + name = name[:-len(suffix)] + break + + return name + + +def detect_layer_quantization(state_dict, prefix="model.diffusion_model."): + """ + Detect per-layer quantization configuration from state dict. + + Detection priority: + 1. Check for _quantization_metadata key (new format) + 2. Check for scaled_fp8 key (legacy format - return None) + 3. Check for per-layer scale_weight patterns (mixed detection) + 4. No quantization detected (return None) + + Args: + state_dict: Model state dictionary + prefix: Key prefix for model layers + + Returns: + Dict mapping layer names to quantization configs, or None for legacy/no quantization. + + Example return value: + { + "input_blocks.5.1.transformer_blocks.0.attn1.to_q": { + "format": "fp8_e4m3fn_scaled", + "params": {"use_fp8_matmul": True} + }, + "middle_block.1.transformer_blocks.0.attn2.to_k": { + "format": "fp8_e5m2_scaled", + "params": {"use_fp8_matmul": True} + } + } + """ + + # 1. Check for new metadata format + metadata_key = f"{prefix}_quantization_metadata" + if metadata_key in state_dict: + try: + metadata = state_dict.pop(metadata_key) + if isinstance(metadata, dict) and "layers" in metadata: + logging.info(f"Found quantization metadata (version {metadata.get('format_version', 'unknown')})") + return metadata["layers"] + else: + logging.warning(f"Invalid quantization metadata format, ignoring") + except Exception as e: + logging.error(f"Failed to parse quantization metadata: {e}") + return None + + # 2. Check for legacy scaled_fp8 marker + # If present, return None to use legacy code path + scaled_fp8_key = f"{prefix}scaled_fp8" + if scaled_fp8_key in state_dict: + logging.debug("Detected legacy scaled_fp8 format, using legacy code path") + return None + + # 3. Check for per-layer scale patterns (mixed precision without metadata) + # Look for layers that have scale_weight but not all layers have it + known_prefixes = [prefix] + layer_configs = {} + layers_with_scale = set() + layers_with_weight = set() + + for key in state_dict.keys(): + if key.startswith(prefix): + if key.endswith(".scale_weight"): + layer_name = normalize_layer_name(key, known_prefixes) + layers_with_scale.add(layer_name) + # Detect format based on weight dtype + weight_key = f"{prefix}{layer_name}.weight" + if weight_key in state_dict: + weight_dtype = state_dict[weight_key].dtype + if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + format_name = "fp8_e4m3fn_scaled" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2_scaled" + layer_configs[layer_name] = { + "format": format_name, + "params": {"use_fp8_matmul": True} + } + elif key.endswith(".weight") and not key.endswith(".scale_weight"): + layer_name = normalize_layer_name(key, known_prefixes) + layers_with_weight.add(layer_name) + + # If we found scale_weight on some but not all layers, it's mixed precision + if layer_configs and len(layers_with_scale) < len(layers_with_weight): + logging.info(f"Detected mixed precision via scale patterns: {len(layers_with_scale)} quantized layers, {len(layers_with_weight)} total layers") + return layer_configs + + # 4. No quantization detected + return None + + def count_blocks(state_dict_keys, prefix_string): count = 0 while True: @@ -701,6 +820,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal else: model_config.optimizations["fp8"] = True + # Detect per-layer quantization (mixed precision) + layer_quant_config = detect_layer_quantization(state_dict, unet_key_prefix) + if layer_quant_config: + model_config.layer_quant_config = layer_quant_config + logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized") + return model_config def unet_prefix_from_state_dict(state_dict): diff --git a/comfy/ops.py b/comfy/ops.py index b2096b40ee37..7ce7d3293451 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -458,7 +458,457 @@ def forward_comfy_cast_weights(self, input): def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) -def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None): + +# ============================================================================== +# Quantization Format Registry System +# ============================================================================== + +class QuantFormatHandler: + """ + Base class for all quantization format handlers. + + A handler encapsulates the logic for a specific quantization format + (e.g., FP8 scaled, MX formats) and manages the quantization + parameters and forward pass for quantized layers. + """ + + def __init__(self, layer, **config): + """ + Initialize handler for a specific layer. + + Args: + layer: The nn.Module layer (Linear, Conv2d, etc.) + **config: Format-specific configuration + """ + self.layer = layer + self.config = config + + def setup_parameters(self): + """ + Initialize quantization parameters on the layer. + Called during layer construction or load_state_dict. + + Subclasses should create parameters like scale_weight, scale_input, etc. + and attach them to self.layer. + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement setup_parameters()") + + def forward(self, *args, **kwargs): + """ + Execute quantized forward pass. + + Signature matches the layer's expected forward pass. + Handler accesses layer parameters via self.layer (weight, bias, etc.) + + Args: + *args: Positional arguments matching layer forward signature + **kwargs: Keyword arguments matching layer forward signature + + Returns: + Layer output tensor + + Examples: + Linear: forward(input) + Conv2d: forward(input) + GroupNorm: forward(input) + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement forward()") + + def load_state_dict(self, state_dict, prefix): + """ + Load quantization parameters from state dict. + + Args: + state_dict: State dictionary + prefix: Key prefix for this layer (e.g., "model.diffusion_model.layer1.") + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement load_state_dict()") + + def state_dict(self, prefix): + """ + Save quantization parameters to state dict. + + Args: + prefix: Key prefix for this layer + + Returns: + Dictionary of quantization parameters with full keys + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement state_dict()") + + def convert_weight(self, weight, inplace=False): + """ + Convert weight from quantized to full precision (dequantize). + + Args: + weight: Quantized weight tensor + inplace: Whether to modify in-place + + Returns: + Dequantized weight tensor + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement convert_weight()") + + def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False): + """ + Convert and set weight from full precision to quantized. + + Args: + weight: Full precision weight tensor + inplace_update: Whether to update layer weight in-place + seed: Random seed for stochastic rounding + return_weight: If True, return quantized weight without setting + + Returns: + Quantized weight if return_weight=True, else None + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement set_weight()") + + +class QuantFormatRegistry: + """ + Global registry for quantization formats. + + Formats are registered with a unique name and handler class. + Custom formats can be registered by custom nodes. + """ + + _formats = {} + + @classmethod + def register(cls, name, handler_class, **default_config): + """ + Register a new quantization format. + + Args: + name: Unique format identifier (e.g., "fp8_e4m3fn_scaled") + handler_class: Handler class implementing QuantFormatHandler + **default_config: Default configuration parameters + + Example: + QuantFormatRegistry.register( + "fp8_e4m3fn_scaled", + handler_class=FP8ScaledHandler, + base_dtype=torch.float8_e4m3fn, + quantize_activation=False, + use_fp8_matmul=True, + ) + """ + if not issubclass(handler_class, QuantFormatHandler): + raise TypeError(f"handler_class must be a subclass of QuantFormatHandler, got {handler_class}") + + cls._formats[name] = { + "handler": handler_class, + "config": default_config.copy() + } + logging.debug(f"Registered quantization format: {name}") + + @classmethod + def get(cls, name, **override_config): + """ + Get format info with optional config overrides. + + Args: + name: Format identifier + **override_config: Configuration overrides + + Returns: + Dict with 'handler' (class) and 'config' (dict) keys + + Raises: + ValueError: If format name not registered + """ + if name not in cls._formats: + available = ", ".join(cls._formats.keys()) if cls._formats else "none" + raise ValueError(f"Unknown quantization format: '{name}'. Available formats: {available}") + + format_info = cls._formats[name].copy() + # Merge override_config into default config + config = format_info["config"].copy() + config.update(override_config) + format_info["config"] = config + return format_info + + @classmethod + def list_formats(cls): + """List all registered format names""" + return list(cls._formats.keys()) + + @classmethod + def is_registered(cls, name): + """Check if a format is registered""" + return name in cls._formats + + +class FP8ScaledHandler(QuantFormatHandler): + """ + Handler for FP8 quantization with per-tensor scaling. + + Supports both weight-only and weight+activation quantization. + Compatible with existing fp8_linear implementation. + """ + + def setup_parameters(self): + """Initialize scale_weight and optionally scale_input""" + device = self.layer.weight.device + dtype = torch.float32 + + # Always have scale_weight for FP8 + if not hasattr(self.layer, 'scale_weight') or self.layer.scale_weight is None: + self.layer.scale_weight = torch.nn.Parameter( + torch.ones((), device=device, dtype=dtype), + requires_grad=False + ) + + # scale_input is optional (for activation quantization) + if self.config.get("quantize_activation", False): + if not hasattr(self.layer, 'scale_input') or self.layer.scale_input is None: + self.layer.scale_input = torch.nn.Parameter( + torch.ones((), device=device, dtype=dtype), + requires_grad=False + ) + else: + self.layer.scale_input = None + + def forward(self, *args, **kwargs): + """ + FP8 forward pass with optional activation quantization. + Supports Linear layers (Conv2d in future). + """ + # Detect layer type and dispatch + if isinstance(self.layer, torch.nn.Linear): + return self._forward_linear(*args, **kwargs) + else: + raise NotImplementedError( + f"FP8ScaledHandler not implemented for {type(self.layer).__name__}" + ) + + def _forward_linear(self, input): + """FP8 forward for Linear layers""" + # Try fast path with fp8_linear if enabled + if self.config.get("use_fp8_matmul", False) and not self.layer.training: + try: + result = fp8_linear(self.layer, input) + if result is not None: + return result + except Exception as e: + logging.debug(f"FP8 matmul failed, falling back to standard path: {e}") + + # Standard path: dequantize and compute + weight, bias = cast_bias_weight(self.layer, input) + + # Dequantize weight + scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) + + # Apply weight functions (LoRA, etc.) - they see dequantized weights + if hasattr(self.layer, 'weight_function') and len(self.layer.weight_function) > 0: + weight = weight * scale + for f in self.layer.weight_function: + weight = f(weight) + else: + weight = weight * scale + + if hasattr(self.layer, 'bias_function') and len(self.layer.bias_function) > 0: + for f in self.layer.bias_function: + bias = f(bias) if bias is not None else None + + # Execute linear operation + # Optimization: multiply by scale on smaller tensor + if weight.numel() < input.numel() and len(self.layer.weight_function) == 0: + return torch.nn.functional.linear(input, weight, bias) + else: + return torch.nn.functional.linear(input, weight, bias) + + def load_state_dict(self, state_dict, prefix): + """Load scale parameters from state dict""" + scale_weight_key = f"{prefix}scale_weight" + if scale_weight_key in state_dict: + self.layer.scale_weight.data.copy_(state_dict[scale_weight_key]) + + scale_input_key = f"{prefix}scale_input" + if scale_input_key in state_dict and self.layer.scale_input is not None: + self.layer.scale_input.data.copy_(state_dict[scale_input_key]) + + def state_dict(self, prefix): + """Save scale parameters to state dict""" + result = {f"{prefix}scale_weight": self.layer.scale_weight} + if self.layer.scale_input is not None: + result[f"{prefix}scale_input"] = self.layer.scale_input + return result + + def convert_weight(self, weight, inplace=False): + """Dequantize: multiply by scale""" + scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) + if inplace: + weight *= scale + return weight + return weight * scale + + def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False): + """Quantize: divide by scale with stochastic rounding""" + scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) + quantized = comfy.float.stochastic_rounding( + weight / scale, + self.layer.weight.dtype, + seed=seed + ) + + if return_weight: + return quantized + + if inplace_update: + self.layer.weight.data.copy_(quantized) + else: + self.layer.weight = torch.nn.Parameter(quantized, requires_grad=False) + + +# ============================================================================== +# Mixed Precision Operations +# ============================================================================== + +class MixedPrecisionOps(disable_weight_init): + """ + Operations class supporting per-layer quantization (mixed precision). + + This class enables different layers to use different quantization formats + within the same model (e.g., some layers FP8, others BF16). + + Layer-specific quantization is configured via _layer_quant_config class variable, + which is set by pick_operations() when a model has mixed precision. + """ + + _layer_quant_config = {} # Class variable set by pick_operations() + + class Linear(disable_weight_init.Linear): + """Linear layer with optional per-layer quantization""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.quant_handler = None + self._handler_initialized = False + + def reset_parameters(self): + # Don't allocate weights - return None like disable_weight_init + return None + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + """ + Called by PyTorch during load_state_dict. + This is where we initialize the handler since we now know the layer name. + """ + if not self._handler_initialized: + # Normalize layer name from prefix + layer_name = prefix.rstrip('.') + + # Strip known model prefixes + for model_prefix in ["model.diffusion_model.", "model.model.", "net."]: + if layer_name.startswith(model_prefix): + layer_name = layer_name[len(model_prefix):] + break + + # Check if this layer has quantization config + # Access via parent class since _layer_quant_config is a class variable + if layer_name in MixedPrecisionOps._layer_quant_config: + config = MixedPrecisionOps._layer_quant_config[layer_name] + try: + format_info = QuantFormatRegistry.get( + config["format"], + **config.get("params", {}) + ) + + # Initialize handler + self.quant_handler = format_info["handler"](self, **format_info["config"]) + self.quant_handler.setup_parameters() + + # Let handler load its parameters (scale_weight, etc.) + self.quant_handler.load_state_dict(state_dict, prefix) + + logging.debug(f"Initialized {config['format']} handler for layer {layer_name}") + except ValueError as e: + # Format not registered - fall back to standard precision + logging.warning( + f"Quantization format '{config['format']}' not registered for layer {layer_name}. " + f"Falling back to standard precision. Error: {e}" + ) + self.quant_handler = None + except Exception as e: + logging.error(f"Failed to initialize quantization handler for {layer_name}: {e}") + self.quant_handler = None + + self._handler_initialized = True + + # Call parent to load weight and bias + super()._load_from_state_dict( + state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs + ) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + """Save layer parameters including quantization metadata""" + super()._save_to_state_dict(destination, prefix, keep_vars) + + # Save handler parameters (scale_weight, etc.) + if self.quant_handler: + handler_dict = self.quant_handler.state_dict(prefix) + destination.update(handler_dict) + + def forward_comfy_cast_weights(self, input): + """Forward pass with optional quantization""" + if self.quant_handler: + # Use handler for quantized forward + return self.quant_handler.forward(input) + else: + # Standard path for non-quantized layers + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.linear(input, weight, bias) + + def forward(self, *args, **kwargs): + """Main forward pass""" + run_every_op() + # Same logic as disable_weight_init.Linear + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + + @classmethod + def conv_nd(s, dims, *args, **kwargs): + """Create Conv layer (same as disable_weight_init)""" + if dims == 2: + return s.Conv2d(*args, **kwargs) + elif dims == 3: + return s.Conv3d(*args, **kwargs) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): + """ + Select appropriate operations class for model. + + NEW: If model_config.layer_quant_config exists, returns MixedPrecisionOps (Phase 3). + LEGACY: All other paths unchanged for backward compatibility. + + Args: + weight_dtype: Weight storage dtype + compute_dtype: Computation dtype + load_device: Device for loading + disable_fast_fp8: Disable fast FP8 paths + fp8_optimizations: Enable FP8 optimizations + scaled_fp8: Legacy FP8 dtype marker + model_config: Model config object (optional, for mixed precision support) + + Returns: + Operations class (e.g., MixedPrecisionOps, fp8_ops, disable_weight_init) + """ + # NEW: Check for mixed precision + if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: + MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config + logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") + return MixedPrecisionOps + + # LEGACY paths (unchanged) fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) @@ -483,3 +933,35 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ return disable_weight_init return manual_cast + + +# ============================================================================== +# Register built-in quantization formats +# ============================================================================== + +# FP8 E4M3FN weight-only quantization +QuantFormatRegistry.register( + "fp8_e4m3fn_scaled", + handler_class=FP8ScaledHandler, + base_dtype=torch.float8_e4m3fn, + quantize_activation=False, + use_fp8_matmul=True, +) + +# FP8 E4M3FN weight+activation quantization +QuantFormatRegistry.register( + "fp8_e4m3fn_scaled_dynamic", + handler_class=FP8ScaledHandler, + base_dtype=torch.float8_e4m3fn, + quantize_activation=True, + use_fp8_matmul=True, +) + +# FP8 E5M2 weight-only quantization +QuantFormatRegistry.register( + "fp8_e5m2_scaled", + handler_class=FP8ScaledHandler, + base_dtype=torch.float8_e5m2, + quantize_activation=False, + use_fp8_matmul=True, +) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 54573abb110d..e4bd7451429b 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -50,6 +50,7 @@ class BASE: manual_cast_dtype = None custom_operations = None scaled_fp8 = None + layer_quant_config = None # Per-layer quantization configuration for mixed precision optimizations = {"fp8": False} @classmethod From bc0ad9bb49b642e081f99f92d239d634988d52bc Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 16 Oct 2025 20:12:50 +0300 Subject: [PATCH 02/49] fix(api-nodes): remove "veo2" model from Veo3 node (#10372) --- comfy_api_nodes/nodes_veo2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 4ab5c518614d..daeaa823e44e 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -393,7 +393,9 @@ def define_schema(cls): ), IO.Combo.Input( "model", - options=list(MODELS_MAP.keys()), + options=[ + "veo-3.1-generate", "veo-3.1-fast-generate", "veo-3.0-generate-001", "veo-3.0-fast-generate-001" + ], default="veo-3.0-generate-001", tooltip="Veo 3 model to use for video generation", optional=True, From 19b466160c1cd43f707769adef6f8ed6e9fd50bf Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 16 Oct 2025 15:16:03 -0700 Subject: [PATCH 03/49] Workaround for nvidia issue where VAE uses 3x more memory on torch 2.9 (#10373) --- comfy/ops.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/comfy/ops.py b/comfy/ops.py index b2096b40ee37..893ceda98463 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -52,6 +52,16 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs): except (ModuleNotFoundError, TypeError): logging.warning("Could not set sdpa backend priority.") +NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False +try: + if comfy.model_management.is_nvidia(): + if torch.backends.cudnn.version() >= 91300 and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10): + #TODO: change upper bound version once it's fixed' + NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True + logging.info("working around nvidia conv3d memory bug.") +except: + pass + cast_to = comfy.model_management.cast_to #TODO: remove once no more references if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast: @@ -151,6 +161,15 @@ class Conv3d(torch.nn.Conv3d, CastWeightBiasOp): def reset_parameters(self): return None + def _conv_forward(self, input, weight, bias, *args, **kwargs): + if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16): + out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True) + if bias is not None: + out += bias.reshape((1, -1) + (1,) * (out.ndim - 2)) + return out + else: + return super()._conv_forward(input, weight, bias, *args, **kwargs) + def forward_comfy_cast_weights(self, input): weight, bias = cast_bias_weight(self, input) return self._conv_forward(input, weight, bias) From b1293d50eff5f1ff2e54f73114fbe7c0f9aef8fe Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 16 Oct 2025 16:59:56 -0700 Subject: [PATCH 04/49] workaround also works on cudnn 91200 (#10375) --- comfy/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index 893ceda98463..56b07b44cbe7 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -55,7 +55,7 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs): NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False try: if comfy.model_management.is_nvidia(): - if torch.backends.cudnn.version() >= 91300 and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10): + if torch.backends.cudnn.version() >= 91200 and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10): #TODO: change upper bound version once it's fixed' NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True logging.info("working around nvidia conv3d memory bug.") From d8d60b56093a15edc5d25486d387d3c5917dc3d3 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 16 Oct 2025 21:39:37 -0700 Subject: [PATCH 05/49] Do batch_slice in EasyCache's apply_cache_diff (#10376) --- comfy_extras/nodes_easycache.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py index c170e9fd9be8..1359e2f99339 100644 --- a/comfy_extras/nodes_easycache.py +++ b/comfy_extras/nodes_easycache.py @@ -244,6 +244,8 @@ def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]): self.total_steps_skipped += 1 batch_offset = x.shape[0] // len(uuids) for i, uuid in enumerate(uuids): + # slice out only what is relevant to this cond + batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)] # if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video) if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]: if not self.allow_mismatch: @@ -261,9 +263,8 @@ def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]): slicing.append(slice(None, dim_u)) else: slicing.append(slice(None)) - slicing = [slice(i*batch_offset,(i+1)*batch_offset)] + slicing - x = x[slicing] - x += self.uuid_cache_diffs[uuid].to(x.device) + batch_slice = batch_slice + slicing + x[batch_slice] += self.uuid_cache_diffs[uuid].to(x.device) return x def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]): From b1467da4803017a418c32c159525767f45871ca3 Mon Sep 17 00:00:00 2001 From: rattus128 <46076784+rattus128@users.noreply.github.com> Date: Sat, 18 Oct 2025 06:55:15 +1000 Subject: [PATCH 06/49] execution: fold in dependency aware caching / Fix --cache-none with loops/lazy etc (#10368) * execution: fold in dependency aware caching This makes --cache-none compatiable with lazy and expanded subgraphs. Currently the --cache-none option is powered by the DependencyAwareCache. The cache attempts to maintain a parallel copy of the execution list data structure, however it is only setup once at the start of execution and does not get meaninigful updates to the execution list. This causes multiple problems when --cache-none is used with lazy and expanded subgraphs as the DAC does not accurately update its copy of the execution data structure. DAC has an attempt to handle subgraphs ensure_subcache however this does not accurately connect to nodes outside the subgraph. The current semantics of DAC are to free a node ASAP after the dependent nodes are executed. This means that if a subgraph refs such a node it will be requed and re-executed by the execution_list but DAC wont see it in its to-free lists anymore and leak memory. Rather than try and cover all the cases where the execution list changes from inside the cache, move the while problem to the executor which maintains an always up-to-date copy of the wanted data-structure. The executor now has a fast-moving run-local cache of its own. Each _to node has its own mini cache, and the cache is unconditionally primed at the time of add_strong_link. add_strong_link is called for all of static workflows, lazy links and expanded subgraphs so its the singular source of truth for output dependendencies. In the case of a cache-hit, the executor cache will hold the non-none value (it will respect updates if they happen somehow as well). In the case of a cache-miss, the executor caches a None and will wait for a notification to update the value when the node completes. When a node completes execution, it simply releases its mini-cache and in turn its strong refs on its direct anscestor outputs, allowing for ASAP freeing (same as the DependencyAwareCache but a little more automatic). This now allows for re-implementation of --cache-none with no cache at all. The dependency aware cache was also observing the dependency sematics for the objects and UI cache which is not accurate (this entire logic was always outputs specific). This also prepares for more complex caching strategies (such as RAM pressure based caching), where a cache can implement any freeing strategy completely independently of the DepedancyAwareness requirement. * main: re-implement --cache-none as no cache at all The execution list now tracks the dependency aware caching more correctly that the DependancyAwareCache. Change it to a cache that does nothing. * test_execution: add --cache-none to the test suite --cache-none is now expected to work universally. Run it through the full unit test suite. Propagate the server parameterization for whether or not the server is capabale of caching, so that the minority of tests that specifically check for cache hits can if else. Hard assert NOT caching in the else to give some coverage of --cache-none expected behaviour to not acutally cache. --- comfy_execution/caching.py | 174 ++++-------------------------- comfy_execution/graph.py | 31 +++++- execution.py | 34 +++--- main.py | 2 +- tests/execution/test_execution.py | 50 +++++---- 5 files changed, 101 insertions(+), 190 deletions(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 41224ce3b82e..566bc3f9c74a 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -265,6 +265,26 @@ async def ensure_subcache_for(self, node_id, children_ids): assert cache is not None return await cache._ensure_subcache(node_id, children_ids) +class NullCache: + + async def set_prompt(self, dynprompt, node_ids, is_changed_cache): + pass + + def all_node_ids(self): + return [] + + def clean_unused(self): + pass + + def get(self, node_id): + return None + + def set(self, node_id, value): + pass + + async def ensure_subcache_for(self, node_id, children_ids): + return self + class LRUCache(BasicCache): def __init__(self, key_class, max_size=100): super().__init__(key_class) @@ -316,157 +336,3 @@ async def ensure_subcache_for(self, node_id, children_ids): self._mark_used(child_id) self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) return self - - -class DependencyAwareCache(BasicCache): - """ - A cache implementation that tracks dependencies between nodes and manages - their execution and caching accordingly. It extends the BasicCache class. - Nodes are removed from this cache once all of their descendants have been - executed. - """ - - def __init__(self, key_class): - """ - Initialize the DependencyAwareCache. - - Args: - key_class: The class used for generating cache keys. - """ - super().__init__(key_class) - self.descendants = {} # Maps node_id -> set of descendant node_ids - self.ancestors = {} # Maps node_id -> set of ancestor node_ids - self.executed_nodes = set() # Tracks nodes that have been executed - - async def set_prompt(self, dynprompt, node_ids, is_changed_cache): - """ - Clear the entire cache and rebuild the dependency graph. - - Args: - dynprompt: The dynamic prompt object containing node information. - node_ids: List of node IDs to initialize the cache for. - is_changed_cache: Flag indicating if the cache has changed. - """ - # Clear all existing cache data - self.cache.clear() - self.subcaches.clear() - self.descendants.clear() - self.ancestors.clear() - self.executed_nodes.clear() - - # Call the parent method to initialize the cache with the new prompt - await super().set_prompt(dynprompt, node_ids, is_changed_cache) - - # Rebuild the dependency graph - self._build_dependency_graph(dynprompt, node_ids) - - def _build_dependency_graph(self, dynprompt, node_ids): - """ - Build the dependency graph for all nodes. - - Args: - dynprompt: The dynamic prompt object containing node information. - node_ids: List of node IDs to build the graph for. - """ - self.descendants.clear() - self.ancestors.clear() - for node_id in node_ids: - self.descendants[node_id] = set() - self.ancestors[node_id] = set() - - for node_id in node_ids: - inputs = dynprompt.get_node(node_id)["inputs"] - for input_data in inputs.values(): - if is_link(input_data): # Check if the input is a link to another node - ancestor_id = input_data[0] - self.descendants[ancestor_id].add(node_id) - self.ancestors[node_id].add(ancestor_id) - - def set(self, node_id, value): - """ - Mark a node as executed and store its value in the cache. - - Args: - node_id: The ID of the node to store. - value: The value to store for the node. - """ - self._set_immediate(node_id, value) - self.executed_nodes.add(node_id) - self._cleanup_ancestors(node_id) - - def get(self, node_id): - """ - Retrieve the cached value for a node. - - Args: - node_id: The ID of the node to retrieve. - - Returns: - The cached value for the node. - """ - return self._get_immediate(node_id) - - async def ensure_subcache_for(self, node_id, children_ids): - """ - Ensure a subcache exists for a node and update dependencies. - - Args: - node_id: The ID of the parent node. - children_ids: List of child node IDs to associate with the parent node. - - Returns: - The subcache object for the node. - """ - subcache = await super()._ensure_subcache(node_id, children_ids) - for child_id in children_ids: - self.descendants[node_id].add(child_id) - self.ancestors[child_id].add(node_id) - return subcache - - def _cleanup_ancestors(self, node_id): - """ - Check if ancestors of a node can be removed from the cache. - - Args: - node_id: The ID of the node whose ancestors are to be checked. - """ - for ancestor_id in self.ancestors.get(node_id, []): - if ancestor_id in self.executed_nodes: - # Remove ancestor if all its descendants have been executed - if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]): - self._remove_node(ancestor_id) - - def _remove_node(self, node_id): - """ - Remove a node from the cache. - - Args: - node_id: The ID of the node to remove. - """ - cache_key = self.cache_key_set.get_data_key(node_id) - if cache_key in self.cache: - del self.cache[cache_key] - subcache_key = self.cache_key_set.get_subcache_key(node_id) - if subcache_key in self.subcaches: - del self.subcaches[subcache_key] - - def clean_unused(self): - """ - Clean up unused nodes. This is a no-op for this cache implementation. - """ - pass - - def recursive_debug_dump(self): - """ - Dump the cache and dependency graph for debugging. - - Returns: - A list containing the cache state and dependency graph. - """ - result = super().recursive_debug_dump() - result.append({ - "descendants": self.descendants, - "ancestors": self.ancestors, - "executed_nodes": list(self.executed_nodes), - }) - return result diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index f4b427265da7..d5bbacde3a02 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -153,8 +153,9 @@ def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None): continue _, _, input_info = self.get_input_info(unique_id, input_name) is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] - if (include_lazy or not is_lazy) and not self.is_cached(from_node_id): - node_ids.append(from_node_id) + if (include_lazy or not is_lazy): + if not self.is_cached(from_node_id): + node_ids.append(from_node_id) links.append((from_node_id, from_socket, unique_id)) for link in links: @@ -194,10 +195,34 @@ def __init__(self, dynprompt, output_cache): super().__init__(dynprompt) self.output_cache = output_cache self.staged_node_id = None + self.execution_cache = {} + self.execution_cache_listeners = {} def is_cached(self, node_id): return self.output_cache.get(node_id) is not None + def cache_link(self, from_node_id, to_node_id): + if not to_node_id in self.execution_cache: + self.execution_cache[to_node_id] = {} + self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) + if not from_node_id in self.execution_cache_listeners: + self.execution_cache_listeners[from_node_id] = set() + self.execution_cache_listeners[from_node_id].add(to_node_id) + + def get_output_cache(self, from_node_id, to_node_id): + if not to_node_id in self.execution_cache: + return None + return self.execution_cache[to_node_id].get(from_node_id) + + def cache_update(self, node_id, value): + if node_id in self.execution_cache_listeners: + for to_node_id in self.execution_cache_listeners[node_id]: + self.execution_cache[to_node_id][node_id] = value + + def add_strong_link(self, from_node_id, from_socket, to_node_id): + super().add_strong_link(from_node_id, from_socket, to_node_id) + self.cache_link(from_node_id, to_node_id) + async def stage_node_execution(self): assert self.staged_node_id is None if self.is_empty(): @@ -277,6 +302,8 @@ def unstage_node_execution(self): def complete_node_execution(self): node_id = self.staged_node_id self.pop_node(node_id) + self.execution_cache.pop(node_id, None) + self.execution_cache_listeners.pop(node_id, None) self.staged_node_id = None def get_nodes_in_cycle(self): diff --git a/execution.py b/execution.py index 1dc35738b823..78c36a4b0556 100644 --- a/execution.py +++ b/execution.py @@ -18,7 +18,7 @@ BasicCache, CacheKeySetID, CacheKeySetInputSignature, - DependencyAwareCache, + NullCache, HierarchicalCache, LRUCache, ) @@ -91,13 +91,13 @@ async def get(self, node_id): class CacheType(Enum): CLASSIC = 0 LRU = 1 - DEPENDENCY_AWARE = 2 + NONE = 2 class CacheSet: def __init__(self, cache_type=None, cache_size=None): - if cache_type == CacheType.DEPENDENCY_AWARE: - self.init_dependency_aware_cache() + if cache_type == CacheType.NONE: + self.init_null_cache() logging.info("Disabling intermediate node cache.") elif cache_type == CacheType.LRU: if cache_size is None: @@ -120,11 +120,12 @@ def init_lru_cache(self, cache_size): self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.objects = HierarchicalCache(CacheKeySetID) - # only hold cached items while the decendents have not executed - def init_dependency_aware_cache(self): - self.outputs = DependencyAwareCache(CacheKeySetInputSignature) - self.ui = DependencyAwareCache(CacheKeySetInputSignature) - self.objects = DependencyAwareCache(CacheKeySetID) + def init_null_cache(self): + self.outputs = NullCache() + #The UI cache is expected to be iterable at the end of each workflow + #so it must cache at least a full workflow. Use Heirachical + self.ui = HierarchicalCache(CacheKeySetInputSignature) + self.objects = NullCache() def recursive_debug_dump(self): result = { @@ -135,7 +136,7 @@ def recursive_debug_dump(self): SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") -def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}): +def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): is_v3 = issubclass(class_def, _ComfyNodeInternal) if is_v3: valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True) @@ -153,10 +154,10 @@ def mark_missing(): if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)): input_unique_id = input_data[0] output_index = input_data[1] - if outputs is None: + if execution_list is None: mark_missing() continue # This might be a lazily-evaluated input - cached_output = outputs.get(input_unique_id) + cached_output = execution_list.get_output_cache(input_unique_id, unique_id) if cached_output is None: mark_missing() continue @@ -405,6 +406,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, cached_output = caches.ui.get(unique_id) or {} server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) get_progress_state().finish_progress(unique_id) + execution_list.cache_update(unique_id, caches.outputs.get(unique_id)) return (ExecutionResult.SUCCESS, None, None) input_data_all = None @@ -434,7 +436,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, for r in result: if is_link(r): source_node, source_output = r[0], r[1] - node_output = caches.outputs.get(source_node)[source_output] + node_output = execution_list.get_output_cache(source_node, unique_id)[source_output] for o in node_output: resolved_output.append(o) @@ -446,7 +448,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) - input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) + input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) @@ -549,11 +551,15 @@ async def await_completion(): subcache.clean_unused() for node_id in new_output_ids: execution_list.add_node(node_id) + execution_list.cache_link(node_id, unique_id) for link in new_output_links: execution_list.add_strong_link(link[0], link[1], unique_id) pending_subgraph_results[unique_id] = cached_outputs return (ExecutionResult.PENDING, None, None) + caches.outputs.set(unique_id, output_data) + execution_list.cache_update(unique_id, output_data) + except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") diff --git a/main.py b/main.py index 35857dba8a6b..4b4c5dcc4294 100644 --- a/main.py +++ b/main.py @@ -173,7 +173,7 @@ def prompt_worker(q, server_instance): if args.cache_lru > 0: cache_type = execution.CacheType.LRU elif args.cache_none: - cache_type = execution.CacheType.DEPENDENCY_AWARE + cache_type = execution.CacheType.NONE e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru) last_gc_collect = 0 diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index ef73ad9fdd83..ace0d2279093 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -152,12 +152,12 @@ class TestExecution: # Initialize server and client # @fixture(scope="class", autouse=True, params=[ - # (use_lru, lru_size) - (False, 0), - (True, 0), - (True, 100), + { "extra_args" : [], "should_cache_results" : True }, + { "extra_args" : ["--cache-lru", 0], "should_cache_results" : True }, + { "extra_args" : ["--cache-lru", 100], "should_cache_results" : True }, + { "extra_args" : ["--cache-none"], "should_cache_results" : False }, ]) - def _server(self, args_pytest, request): + def server(self, args_pytest, request): # Start server pargs = [ 'python','main.py', @@ -167,12 +167,10 @@ def _server(self, args_pytest, request): '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml', '--cpu', ] - use_lru, lru_size = request.param - if use_lru: - pargs += ['--cache-lru', str(lru_size)] + pargs += [ str(param) for param in request.param["extra_args"] ] print("Running server with args:", pargs) # noqa: T201 p = subprocess.Popen(pargs) - yield + yield request.param p.kill() torch.cuda.empty_cache() @@ -193,7 +191,7 @@ def start_client(self, listen:str, port:int): return comfy_client @fixture(scope="class", autouse=True) - def shared_client(self, args_pytest, _server): + def shared_client(self, args_pytest, server): client = self.start_client(args_pytest["listen"], args_pytest["port"]) yield client del client @@ -225,7 +223,7 @@ def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder): assert result.did_run(mask) assert result.did_run(lazy_mix) - def test_full_cache(self, client: ComfyClient, builder: GraphBuilder): + def test_full_cache(self, client: ComfyClient, builder: GraphBuilder, server): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) @@ -237,9 +235,12 @@ def test_full_cache(self, client: ComfyClient, builder: GraphBuilder): client.run(g) result2 = client.run(g) for node_id, node in g.nodes.items(): - assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" + if server["should_cache_results"]: + assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" + else: + assert result2.did_run(node), f"Node {node_id} was cached, but should have been run" - def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder): + def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder, server): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) @@ -251,8 +252,12 @@ def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder): client.run(g) mask.inputs['value'] = 0.4 result2 = client.run(g) - assert not result2.did_run(input1), "Input1 should have been cached" - assert not result2.did_run(input2), "Input2 should have been cached" + if server["should_cache_results"]: + assert not result2.did_run(input1), "Input1 should have been cached" + assert not result2.did_run(input2), "Input2 should have been cached" + else: + assert result2.did_run(input1), "Input1 should have been rerun" + assert result2.did_run(input2), "Input2 should have been rerun" def test_error(self, client: ComfyClient, builder: GraphBuilder): g = builder @@ -411,7 +416,7 @@ def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder): input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1) client.run(g) - def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): + def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder, server): g = builder # Creating the nodes in this specific order previously caused a bug save = g.node("SaveImage") @@ -427,7 +432,10 @@ def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): result3 = client.run(g) result4 = client.run(g) assert result1.did_run(is_changed), "is_changed should have been run" - assert not result2.did_run(is_changed), "is_changed should have been cached" + if server["should_cache_results"]: + assert not result2.did_run(is_changed), "is_changed should have been cached" + else: + assert result2.did_run(is_changed), "is_changed should have been re-run" assert result3.did_run(is_changed), "is_changed should have been re-run" assert result4.did_run(is_changed), "is_changed should not have been cached" @@ -514,7 +522,7 @@ def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder): assert len(images2) == 1, "Should have 1 image" # This tests that only constant outputs are used in the call to `IS_CHANGED` - def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder): + def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server): g = builder input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1) test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5) @@ -530,7 +538,11 @@ def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilde images = result.get_images(output) assert len(images) == 1, "Should have 1 image" assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" - assert not result.did_run(test_node), "The execution should have been cached" + if server["should_cache_results"]: + assert not result.did_run(test_node), "The execution should have been cached" + else: + assert result.did_run(test_node), "The execution should have been re-run" + def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): # Warmup execution to ensure server is fully initialized From 99ce2a1f66c4bcd500d76cc9a7430f7b2bf32776 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 18 Oct 2025 00:13:05 +0300 Subject: [PATCH 07/49] convert nodes_controlnet.py to V3 schema (#10202) --- comfy_extras/nodes_controlnet.py | 94 ++++++++++++++++++++------------ 1 file changed, 59 insertions(+), 35 deletions(-) diff --git a/comfy_extras/nodes_controlnet.py b/comfy_extras/nodes_controlnet.py index 2d20e1fed7c2..e835feed71fc 100644 --- a/comfy_extras/nodes_controlnet.py +++ b/comfy_extras/nodes_controlnet.py @@ -1,20 +1,26 @@ from comfy.cldm.control_types import UNION_CONTROLNET_TYPES import nodes import comfy.utils +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class SetUnionControlNetType: +class SetUnionControlNetType(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"control_net": ("CONTROL_NET", ), - "type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),) - }} + def define_schema(cls): + return io.Schema( + node_id="SetUnionControlNetType", + category="conditioning/controlnet", + inputs=[ + io.ControlNet.Input("control_net"), + io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())), + ], + outputs=[ + io.ControlNet.Output(), + ], + ) - CATEGORY = "conditioning/controlnet" - RETURN_TYPES = ("CONTROL_NET",) - - FUNCTION = "set_controlnet_type" - - def set_controlnet_type(self, control_net, type): + @classmethod + def execute(cls, control_net, type) -> io.NodeOutput: control_net = control_net.copy() type_number = UNION_CONTROLNET_TYPES.get(type, -1) if type_number >= 0: @@ -22,27 +28,36 @@ def set_controlnet_type(self, control_net, type): else: control_net.set_extra_arg("control_type", []) - return (control_net,) + return io.NodeOutput(control_net) + + set_controlnet_type = execute # TODO: remove + -class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced): +class ControlNetInpaintingAliMamaApply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "control_net": ("CONTROL_NET", ), - "vae": ("VAE", ), - "image": ("IMAGE", ), - "mask": ("MASK", ), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) - }} - - FUNCTION = "apply_inpaint_controlnet" - - CATEGORY = "conditioning/controlnet" - - def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent): + def define_schema(cls): + return io.Schema( + node_id="ControlNetInpaintingAliMamaApply", + category="conditioning/controlnet", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.ControlNet.Input("control_net"), + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Mask.Input("mask"), + io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001), + io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + ], + ) + + @classmethod + def execute(cls, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent) -> io.NodeOutput: extra_concat = [] if control_net.concat_mask: mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) @@ -50,11 +65,20 @@ def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3]) extra_concat = [mask] - return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat) + result = nodes.ControlNetApplyAdvanced().apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat) + return io.NodeOutput(result[0], result[1]) + + apply_inpaint_controlnet = execute # TODO: remove + +class ControlNetExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SetUnionControlNetType, + ControlNetInpaintingAliMamaApply, + ] -NODE_CLASS_MAPPINGS = { - "SetUnionControlNetType": SetUnionControlNetType, - "ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply, -} +async def comfy_entrypoint() -> ControlNetExtension: + return ControlNetExtension() From 92d97380bd02d9883295aeb2d29365cecd9a765e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 17 Oct 2025 15:22:59 -0700 Subject: [PATCH 08/49] Update Python 3.14 installation instructions (#10385) Removed mention of installing pytorch nightly for Python 3.14. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b0731db33bb0..c9a0644e33b5 100644 --- a/README.md +++ b/README.md @@ -197,7 +197,7 @@ comfy install ## Manual Install (Windows, Linux) -Python 3.14 will work if you comment out the `kornia` dependency in the requirements.txt file (breaks the canny node) and install pytorch nightly but it is not recommended. +Python 3.14 will work if you comment out the `kornia` dependency in the requirements.txt file (breaks the canny node) but it is not recommended. Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 From 9da397ea2f271080406f0c14cf4f0db7221ddf70 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 17 Oct 2025 17:03:28 -0700 Subject: [PATCH 09/49] Disable torch compiler for cast_bias_weight function (#10384) * Disable torch compiler for cast_bias_weight function * Fix torch compile. --- comfy/ops.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/ops.py b/comfy/ops.py index 56b07b44cbe7..5feeb3571010 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -25,6 +25,9 @@ import contextlib def run_every_op(): + if torch.compiler.is_compiling(): + return + comfy.model_management.throw_exception_if_processing_interrupted() def scaled_dot_product_attention(q, k, v, *args, **kwargs): @@ -70,6 +73,7 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs): def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) +@torch.compiler.disable() def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): if input is not None: if dtype is None: From 5b80addafd24bda5b2f9f7a35e32dbd40823c3fd Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 18 Oct 2025 19:35:46 -0700 Subject: [PATCH 10/49] Turn off cuda malloc by default when --fast autotune is turned on. (#10393) --- comfy/model_management.py | 3 +++ comfy/ops.py | 3 --- cuda_malloc.py | 7 ++++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index d82d5b8b00ae..7467391cd9bf 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -371,6 +371,9 @@ def amd_min_version(device=None, min_rdna_version=0): except: pass +if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast: + torch.backends.cudnn.benchmark = True + try: if torch_version_numeric >= (2, 5): torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) diff --git a/comfy/ops.py b/comfy/ops.py index 5feeb3571010..967134f0522a 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -67,9 +67,6 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs): cast_to = comfy.model_management.cast_to #TODO: remove once no more references -if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast: - torch.backends.cudnn.benchmark = True - def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) diff --git a/cuda_malloc.py b/cuda_malloc.py index c1d9ae3cab5c..6520d51230af 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -1,6 +1,6 @@ import os import importlib.util -from comfy.cli_args import args +from comfy.cli_args import args, PerformanceFeature import subprocess #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. @@ -75,8 +75,9 @@ def cuda_malloc_supported(): spec.loader.exec_module(module) version = module.__version__ - if int(version[0]) >= 2 and "+cu" in version: #enable by default for torch version 2.0 and up only on cuda torch - args.cuda_malloc = cuda_malloc_supported() + if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch + if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc + args.cuda_malloc = cuda_malloc_supported() except: pass From 0cf33953a7c951d163088cbfe36c55d1cdf8a718 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 18 Oct 2025 20:15:34 -0700 Subject: [PATCH 11/49] Fix batch size above 1 giving bad output in chroma radiance. (#10394) --- comfy/ldm/chroma_radiance/model.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index 47aa11b04545..7d7be80f5943 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -189,15 +189,15 @@ def forward_nerf( nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size) nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P] + # Reshape for per-patch processing + nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size) + nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2) + if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size: # Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than # the tile size. - img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params) + img_dct = self.forward_tiled_nerf(nerf_hidden, nerf_pixels, B, C, num_patches, patch_size, params) else: - # Reshape for per-patch processing - nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size) - nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2) - # Get DCT-encoded pixel embeddings [pixel-dct] img_dct = self.nerf_image_embedder(nerf_pixels) @@ -240,17 +240,8 @@ def forward_tiled_nerf( end = min(i + tile_size, num_patches) # Slice the current tile from the input tensors - nerf_hidden_tile = nerf_hidden[:, i:end, :] - nerf_pixels_tile = nerf_pixels[:, i:end, :] - - # Get the actual number of patches in this tile (can be smaller for the last tile) - num_patches_tile = nerf_hidden_tile.shape[1] - - # Reshape the tile for per-patch processing - # [B, NumPatches_tile, D] -> [B * NumPatches_tile, D] - nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size) - # [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C] - nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2) + nerf_hidden_tile = nerf_hidden[i * batch:end * batch] + nerf_pixels_tile = nerf_pixels[i * batch:end * batch] # get DCT-encoded pixel embeddings [pixel-dct] img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile) From dad076aee68ab676fb390d9663ab9e343824a080 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 18 Oct 2025 20:19:52 -0700 Subject: [PATCH 12/49] Speed up chroma radiance. (#10395) --- comfy/model_detection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7677617c0efb..141f1e164834 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -213,7 +213,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["nerf_mlp_ratio"] = 4 dit_config["nerf_depth"] = 4 dit_config["nerf_max_freqs"] = 8 - dit_config["nerf_tile_size"] = 32 + dit_config["nerf_tile_size"] = 512 dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" dit_config["nerf_embedder_dtype"] = torch.float32 else: From b4f30bd4087a79b4c4fc89bb67b9889adb866294 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 18 Oct 2025 22:25:35 -0700 Subject: [PATCH 13/49] Pytorch is stupid. (#10398) --- comfy/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index 967134f0522a..934e21261edc 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -58,7 +58,7 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs): NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False try: if comfy.model_management.is_nvidia(): - if torch.backends.cudnn.version() >= 91200 and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10): + if torch.backends.cudnn.version() >= 91002 and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10): #TODO: change upper bound version once it's fixed' NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True logging.info("working around nvidia conv3d memory bug.") From b5c59b763c6b14e1362ec4274b09eca4f3f7091b Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sun, 19 Oct 2025 13:05:46 -0700 Subject: [PATCH 14/49] Deprecation warning on unused files (#10387) * only warn for unused files * include internal extensions --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index a44f4f2372e7..10c2698b55ab 100644 --- a/server.py +++ b/server.py @@ -56,7 +56,7 @@ async def deprecation_warning(request: web.Request, handler): """Middleware to warn about deprecated frontend API paths""" path = request.path - if (path.startswith('/scripts/') or path.startswith('/extensions/core/')): + if path.startswith("/scripts/ui") or path.startswith("/extensions/core/"): # Only warn once per unique file path if path not in _deprecated_paths_warned: _deprecated_paths_warned.add(path) From a4787ac83bf6c83eeb459ed80fc9b36f63d2a3a7 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 21 Oct 2025 03:28:36 +0800 Subject: [PATCH 15/49] Update template to 0.2.1 (#10413) * Update template to 0.1.97 * Update template to 0.2.1 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 82457df54a74..dd2afcab0564 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.28.7 -comfyui-workflow-templates==0.1.95 +comfyui-workflow-templates==0.2.1 comfyui-embedded-docs==0.3.0 torch torchsde From 2c2aa409b01f513de88d2245931e5836ed1cd718 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 20 Oct 2025 12:43:24 -0700 Subject: [PATCH 16/49] Log message for cudnn disable on AMD. (#10418) --- comfy/model_management.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 7467391cd9bf..a2c318ec3e00 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -333,6 +333,7 @@ def amd_min_version(device=None, min_rdna_version=0): try: if is_amd(): torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD + logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") try: rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2])) except: From b7992f871af38d89a459080caa57cc359ed93a46 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 20 Oct 2025 16:03:06 -0700 Subject: [PATCH 17/49] =?UTF-8?q?Revert=20"execution:=20fold=20in=20depend?= =?UTF-8?q?ency=20aware=20caching=20/=20Fix=20--cache-none=20with=20l?= =?UTF-8?q?=E2=80=A6"=20(#10422)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit b1467da4803017a418c32c159525767f45871ca3. --- comfy_execution/caching.py | 174 ++++++++++++++++++++++++++---- comfy_execution/graph.py | 31 +----- execution.py | 34 +++--- main.py | 2 +- tests/execution/test_execution.py | 50 ++++----- 5 files changed, 190 insertions(+), 101 deletions(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 566bc3f9c74a..41224ce3b82e 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -265,26 +265,6 @@ async def ensure_subcache_for(self, node_id, children_ids): assert cache is not None return await cache._ensure_subcache(node_id, children_ids) -class NullCache: - - async def set_prompt(self, dynprompt, node_ids, is_changed_cache): - pass - - def all_node_ids(self): - return [] - - def clean_unused(self): - pass - - def get(self, node_id): - return None - - def set(self, node_id, value): - pass - - async def ensure_subcache_for(self, node_id, children_ids): - return self - class LRUCache(BasicCache): def __init__(self, key_class, max_size=100): super().__init__(key_class) @@ -336,3 +316,157 @@ async def ensure_subcache_for(self, node_id, children_ids): self._mark_used(child_id) self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) return self + + +class DependencyAwareCache(BasicCache): + """ + A cache implementation that tracks dependencies between nodes and manages + their execution and caching accordingly. It extends the BasicCache class. + Nodes are removed from this cache once all of their descendants have been + executed. + """ + + def __init__(self, key_class): + """ + Initialize the DependencyAwareCache. + + Args: + key_class: The class used for generating cache keys. + """ + super().__init__(key_class) + self.descendants = {} # Maps node_id -> set of descendant node_ids + self.ancestors = {} # Maps node_id -> set of ancestor node_ids + self.executed_nodes = set() # Tracks nodes that have been executed + + async def set_prompt(self, dynprompt, node_ids, is_changed_cache): + """ + Clear the entire cache and rebuild the dependency graph. + + Args: + dynprompt: The dynamic prompt object containing node information. + node_ids: List of node IDs to initialize the cache for. + is_changed_cache: Flag indicating if the cache has changed. + """ + # Clear all existing cache data + self.cache.clear() + self.subcaches.clear() + self.descendants.clear() + self.ancestors.clear() + self.executed_nodes.clear() + + # Call the parent method to initialize the cache with the new prompt + await super().set_prompt(dynprompt, node_ids, is_changed_cache) + + # Rebuild the dependency graph + self._build_dependency_graph(dynprompt, node_ids) + + def _build_dependency_graph(self, dynprompt, node_ids): + """ + Build the dependency graph for all nodes. + + Args: + dynprompt: The dynamic prompt object containing node information. + node_ids: List of node IDs to build the graph for. + """ + self.descendants.clear() + self.ancestors.clear() + for node_id in node_ids: + self.descendants[node_id] = set() + self.ancestors[node_id] = set() + + for node_id in node_ids: + inputs = dynprompt.get_node(node_id)["inputs"] + for input_data in inputs.values(): + if is_link(input_data): # Check if the input is a link to another node + ancestor_id = input_data[0] + self.descendants[ancestor_id].add(node_id) + self.ancestors[node_id].add(ancestor_id) + + def set(self, node_id, value): + """ + Mark a node as executed and store its value in the cache. + + Args: + node_id: The ID of the node to store. + value: The value to store for the node. + """ + self._set_immediate(node_id, value) + self.executed_nodes.add(node_id) + self._cleanup_ancestors(node_id) + + def get(self, node_id): + """ + Retrieve the cached value for a node. + + Args: + node_id: The ID of the node to retrieve. + + Returns: + The cached value for the node. + """ + return self._get_immediate(node_id) + + async def ensure_subcache_for(self, node_id, children_ids): + """ + Ensure a subcache exists for a node and update dependencies. + + Args: + node_id: The ID of the parent node. + children_ids: List of child node IDs to associate with the parent node. + + Returns: + The subcache object for the node. + """ + subcache = await super()._ensure_subcache(node_id, children_ids) + for child_id in children_ids: + self.descendants[node_id].add(child_id) + self.ancestors[child_id].add(node_id) + return subcache + + def _cleanup_ancestors(self, node_id): + """ + Check if ancestors of a node can be removed from the cache. + + Args: + node_id: The ID of the node whose ancestors are to be checked. + """ + for ancestor_id in self.ancestors.get(node_id, []): + if ancestor_id in self.executed_nodes: + # Remove ancestor if all its descendants have been executed + if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]): + self._remove_node(ancestor_id) + + def _remove_node(self, node_id): + """ + Remove a node from the cache. + + Args: + node_id: The ID of the node to remove. + """ + cache_key = self.cache_key_set.get_data_key(node_id) + if cache_key in self.cache: + del self.cache[cache_key] + subcache_key = self.cache_key_set.get_subcache_key(node_id) + if subcache_key in self.subcaches: + del self.subcaches[subcache_key] + + def clean_unused(self): + """ + Clean up unused nodes. This is a no-op for this cache implementation. + """ + pass + + def recursive_debug_dump(self): + """ + Dump the cache and dependency graph for debugging. + + Returns: + A list containing the cache state and dependency graph. + """ + result = super().recursive_debug_dump() + result.append({ + "descendants": self.descendants, + "ancestors": self.ancestors, + "executed_nodes": list(self.executed_nodes), + }) + return result diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index d5bbacde3a02..f4b427265da7 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -153,9 +153,8 @@ def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None): continue _, _, input_info = self.get_input_info(unique_id, input_name) is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] - if (include_lazy or not is_lazy): - if not self.is_cached(from_node_id): - node_ids.append(from_node_id) + if (include_lazy or not is_lazy) and not self.is_cached(from_node_id): + node_ids.append(from_node_id) links.append((from_node_id, from_socket, unique_id)) for link in links: @@ -195,34 +194,10 @@ def __init__(self, dynprompt, output_cache): super().__init__(dynprompt) self.output_cache = output_cache self.staged_node_id = None - self.execution_cache = {} - self.execution_cache_listeners = {} def is_cached(self, node_id): return self.output_cache.get(node_id) is not None - def cache_link(self, from_node_id, to_node_id): - if not to_node_id in self.execution_cache: - self.execution_cache[to_node_id] = {} - self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) - if not from_node_id in self.execution_cache_listeners: - self.execution_cache_listeners[from_node_id] = set() - self.execution_cache_listeners[from_node_id].add(to_node_id) - - def get_output_cache(self, from_node_id, to_node_id): - if not to_node_id in self.execution_cache: - return None - return self.execution_cache[to_node_id].get(from_node_id) - - def cache_update(self, node_id, value): - if node_id in self.execution_cache_listeners: - for to_node_id in self.execution_cache_listeners[node_id]: - self.execution_cache[to_node_id][node_id] = value - - def add_strong_link(self, from_node_id, from_socket, to_node_id): - super().add_strong_link(from_node_id, from_socket, to_node_id) - self.cache_link(from_node_id, to_node_id) - async def stage_node_execution(self): assert self.staged_node_id is None if self.is_empty(): @@ -302,8 +277,6 @@ def unstage_node_execution(self): def complete_node_execution(self): node_id = self.staged_node_id self.pop_node(node_id) - self.execution_cache.pop(node_id, None) - self.execution_cache_listeners.pop(node_id, None) self.staged_node_id = None def get_nodes_in_cycle(self): diff --git a/execution.py b/execution.py index 78c36a4b0556..1dc35738b823 100644 --- a/execution.py +++ b/execution.py @@ -18,7 +18,7 @@ BasicCache, CacheKeySetID, CacheKeySetInputSignature, - NullCache, + DependencyAwareCache, HierarchicalCache, LRUCache, ) @@ -91,13 +91,13 @@ async def get(self, node_id): class CacheType(Enum): CLASSIC = 0 LRU = 1 - NONE = 2 + DEPENDENCY_AWARE = 2 class CacheSet: def __init__(self, cache_type=None, cache_size=None): - if cache_type == CacheType.NONE: - self.init_null_cache() + if cache_type == CacheType.DEPENDENCY_AWARE: + self.init_dependency_aware_cache() logging.info("Disabling intermediate node cache.") elif cache_type == CacheType.LRU: if cache_size is None: @@ -120,12 +120,11 @@ def init_lru_cache(self, cache_size): self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.objects = HierarchicalCache(CacheKeySetID) - def init_null_cache(self): - self.outputs = NullCache() - #The UI cache is expected to be iterable at the end of each workflow - #so it must cache at least a full workflow. Use Heirachical - self.ui = HierarchicalCache(CacheKeySetInputSignature) - self.objects = NullCache() + # only hold cached items while the decendents have not executed + def init_dependency_aware_cache(self): + self.outputs = DependencyAwareCache(CacheKeySetInputSignature) + self.ui = DependencyAwareCache(CacheKeySetInputSignature) + self.objects = DependencyAwareCache(CacheKeySetID) def recursive_debug_dump(self): result = { @@ -136,7 +135,7 @@ def recursive_debug_dump(self): SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") -def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): +def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}): is_v3 = issubclass(class_def, _ComfyNodeInternal) if is_v3: valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True) @@ -154,10 +153,10 @@ def mark_missing(): if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)): input_unique_id = input_data[0] output_index = input_data[1] - if execution_list is None: + if outputs is None: mark_missing() continue # This might be a lazily-evaluated input - cached_output = execution_list.get_output_cache(input_unique_id, unique_id) + cached_output = outputs.get(input_unique_id) if cached_output is None: mark_missing() continue @@ -406,7 +405,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, cached_output = caches.ui.get(unique_id) or {} server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) get_progress_state().finish_progress(unique_id) - execution_list.cache_update(unique_id, caches.outputs.get(unique_id)) return (ExecutionResult.SUCCESS, None, None) input_data_all = None @@ -436,7 +434,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, for r in result: if is_link(r): source_node, source_output = r[0], r[1] - node_output = execution_list.get_output_cache(source_node, unique_id)[source_output] + node_output = caches.outputs.get(source_node)[source_output] for o in node_output: resolved_output.append(o) @@ -448,7 +446,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) - input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) + input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) @@ -551,15 +549,11 @@ async def await_completion(): subcache.clean_unused() for node_id in new_output_ids: execution_list.add_node(node_id) - execution_list.cache_link(node_id, unique_id) for link in new_output_links: execution_list.add_strong_link(link[0], link[1], unique_id) pending_subgraph_results[unique_id] = cached_outputs return (ExecutionResult.PENDING, None, None) - caches.outputs.set(unique_id, output_data) - execution_list.cache_update(unique_id, output_data) - except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") diff --git a/main.py b/main.py index 4b4c5dcc4294..35857dba8a6b 100644 --- a/main.py +++ b/main.py @@ -173,7 +173,7 @@ def prompt_worker(q, server_instance): if args.cache_lru > 0: cache_type = execution.CacheType.LRU elif args.cache_none: - cache_type = execution.CacheType.NONE + cache_type = execution.CacheType.DEPENDENCY_AWARE e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru) last_gc_collect = 0 diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index ace0d2279093..ef73ad9fdd83 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -152,12 +152,12 @@ class TestExecution: # Initialize server and client # @fixture(scope="class", autouse=True, params=[ - { "extra_args" : [], "should_cache_results" : True }, - { "extra_args" : ["--cache-lru", 0], "should_cache_results" : True }, - { "extra_args" : ["--cache-lru", 100], "should_cache_results" : True }, - { "extra_args" : ["--cache-none"], "should_cache_results" : False }, + # (use_lru, lru_size) + (False, 0), + (True, 0), + (True, 100), ]) - def server(self, args_pytest, request): + def _server(self, args_pytest, request): # Start server pargs = [ 'python','main.py', @@ -167,10 +167,12 @@ def server(self, args_pytest, request): '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml', '--cpu', ] - pargs += [ str(param) for param in request.param["extra_args"] ] + use_lru, lru_size = request.param + if use_lru: + pargs += ['--cache-lru', str(lru_size)] print("Running server with args:", pargs) # noqa: T201 p = subprocess.Popen(pargs) - yield request.param + yield p.kill() torch.cuda.empty_cache() @@ -191,7 +193,7 @@ def start_client(self, listen:str, port:int): return comfy_client @fixture(scope="class", autouse=True) - def shared_client(self, args_pytest, server): + def shared_client(self, args_pytest, _server): client = self.start_client(args_pytest["listen"], args_pytest["port"]) yield client del client @@ -223,7 +225,7 @@ def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder): assert result.did_run(mask) assert result.did_run(lazy_mix) - def test_full_cache(self, client: ComfyClient, builder: GraphBuilder, server): + def test_full_cache(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) @@ -235,12 +237,9 @@ def test_full_cache(self, client: ComfyClient, builder: GraphBuilder, server): client.run(g) result2 = client.run(g) for node_id, node in g.nodes.items(): - if server["should_cache_results"]: - assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" - else: - assert result2.did_run(node), f"Node {node_id} was cached, but should have been run" + assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" - def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder, server): + def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) @@ -252,12 +251,8 @@ def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder, server) client.run(g) mask.inputs['value'] = 0.4 result2 = client.run(g) - if server["should_cache_results"]: - assert not result2.did_run(input1), "Input1 should have been cached" - assert not result2.did_run(input2), "Input2 should have been cached" - else: - assert result2.did_run(input1), "Input1 should have been rerun" - assert result2.did_run(input2), "Input2 should have been rerun" + assert not result2.did_run(input1), "Input1 should have been cached" + assert not result2.did_run(input2), "Input2 should have been cached" def test_error(self, client: ComfyClient, builder: GraphBuilder): g = builder @@ -416,7 +411,7 @@ def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder): input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1) client.run(g) - def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder, server): + def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): g = builder # Creating the nodes in this specific order previously caused a bug save = g.node("SaveImage") @@ -432,10 +427,7 @@ def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder, ser result3 = client.run(g) result4 = client.run(g) assert result1.did_run(is_changed), "is_changed should have been run" - if server["should_cache_results"]: - assert not result2.did_run(is_changed), "is_changed should have been cached" - else: - assert result2.did_run(is_changed), "is_changed should have been re-run" + assert not result2.did_run(is_changed), "is_changed should have been cached" assert result3.did_run(is_changed), "is_changed should have been re-run" assert result4.did_run(is_changed), "is_changed should not have been cached" @@ -522,7 +514,7 @@ def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder): assert len(images2) == 1, "Should have 1 image" # This tests that only constant outputs are used in the call to `IS_CHANGED` - def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server): + def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1) test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5) @@ -538,11 +530,7 @@ def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilde images = result.get_images(output) assert len(images) == 1, "Should have 1 image" assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" - if server["should_cache_results"]: - assert not result.did_run(test_node), "The execution should have been cached" - else: - assert result.did_run(test_node), "The execution should have been re-run" - + assert not result.did_run(test_node), "The execution should have been cached" def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): # Warmup execution to ensure server is fully initialized From 560b1bdfca77d9441ca2924fd9d6baa8dda05cd7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 20 Oct 2025 15:44:38 -0400 Subject: [PATCH 18/49] ComfyUI version v0.3.66 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index d39c1fdc46f5..33a06bbb08fc 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.65" +__version__ = "0.3.66" diff --git a/pyproject.toml b/pyproject.toml index 653604e24f0e..0c6b23a253d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.65" +version = "0.3.66" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 9cdc64998f8990aed7688b0ebe89bc3b97733764 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 21 Oct 2025 16:15:23 -0700 Subject: [PATCH 19/49] Only disable cudnn on newer AMD GPUs. (#10437) --- comfy/model_management.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index a2c318ec3e00..79d6ff9d441d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -330,15 +330,21 @@ def amd_min_version(device=None, min_rdna_version=0): SUPPORT_FP8_OPS = args.supports_fp8_compute + +AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"] + try: if is_amd(): - torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD - logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") + arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName + if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)): + torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD + logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") + try: rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2])) except: rocm_version = (6, -1) - arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName + logging.info("AMD arch: {}".format(arch)) logging.info("ROCm version: {}".format(rocm_version)) if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: @@ -1331,7 +1337,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma if is_amd(): arch = torch.cuda.get_device_properties(device).gcnArchName - if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16 + if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16 if manual_cast: return True return False From f13cff0be65e35d34876b173bba2fec6bd94746b Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 21 Oct 2025 20:16:16 -0700 Subject: [PATCH 20/49] Add custom node published subgraphs endpoint (#10438) * Add get_subgraphs_dir to ComfyExtension and PUBLISHED_SUBGRAPH_DIRS to nodes.py * Created initial endpoints, although the returned paths are a bit off currently * Fix path and actually return real data * Sanitize returned /api/global_subgraphs entries * Remove leftover function from early prototyping * Remove added whitespace * Add None check for sanitize_entry --- app/subgraph_manager.py | 112 ++++++++++++++++++++++++++++++++++++++++ server.py | 3 ++ 2 files changed, 115 insertions(+) create mode 100644 app/subgraph_manager.py diff --git a/app/subgraph_manager.py b/app/subgraph_manager.py new file mode 100644 index 000000000000..dbe40454169e --- /dev/null +++ b/app/subgraph_manager.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from typing import TypedDict +import os +import folder_paths +import glob +from aiohttp import web +import hashlib + + +class Source: + custom_node = "custom_node" + +class SubgraphEntry(TypedDict): + source: str + """ + Source of subgraph - custom_nodes vs templates. + """ + path: str + """ + Relative path of the subgraph file. + For custom nodes, will be the relative directory like /subgraphs/.json + """ + name: str + """ + Name of subgraph file. + """ + info: CustomNodeSubgraphEntryInfo + """ + Additional info about subgraph; in the case of custom_nodes, will contain nodepack name + """ + data: str + +class CustomNodeSubgraphEntryInfo(TypedDict): + node_pack: str + """Node pack name.""" + +class SubgraphManager: + def __init__(self): + self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None + + async def load_entry_data(self, entry: SubgraphEntry): + with open(entry['path'], 'r') as f: + entry['data'] = f.read() + return entry + + async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None: + if entry is None: + return None + entry = entry.copy() + entry.pop('path', None) + if remove_data: + entry.pop('data', None) + return entry + + async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]: + entries = entries.copy() + for key in list(entries.keys()): + entries[key] = await self.sanitize_entry(entries[key], remove_data) + return entries + + async def get_custom_node_subgraphs(self, loadedModules, force_reload=False): + # if not forced to reload and cached, return cache + if not force_reload and self.cached_custom_node_subgraphs is not None: + return self.cached_custom_node_subgraphs + # Load subgraphs from custom nodes + subfolder = "subgraphs" + subgraphs_dict: dict[SubgraphEntry] = {} + + for folder in folder_paths.get_folder_paths("custom_nodes"): + pattern = os.path.join(folder, f"*/{subfolder}/*.json") + matched_files = glob.glob(pattern) + for file in matched_files: + # replace backslashes with forward slashes + file = file.replace('\\', '/') + info: CustomNodeSubgraphEntryInfo = { + "node_pack": "custom_nodes." + file.split('/')[-3] + } + source = Source.custom_node + # hash source + path to make sure id will be as unique as possible, but + # reproducible across backend reloads + id = hashlib.sha256(f"{source}{file}".encode()).hexdigest() + entry: SubgraphEntry = { + "source": Source.custom_node, + "name": os.path.splitext(os.path.basename(file))[0], + "path": file, + "info": info, + } + subgraphs_dict[id] = entry + self.cached_custom_node_subgraphs = subgraphs_dict + return subgraphs_dict + + async def get_custom_node_subgraph(self, id: str, loadedModules): + subgraphs = await self.get_custom_node_subgraphs(loadedModules) + entry: SubgraphEntry = subgraphs.get(id, None) + if entry is not None and entry.get('data', None) is None: + await self.load_entry_data(entry) + return entry + + def add_routes(self, routes, loadedModules): + @routes.get("/global_subgraphs") + async def get_global_subgraphs(request): + subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules) + # NOTE: we may want to include other sources of global subgraphs such as templates in the future; + # that's the reasoning for the current implementation + return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True)) + + @routes.get("/global_subgraphs/{id}") + async def get_global_subgraph(request): + id = request.match_info.get("id", None) + subgraph = await self.get_custom_node_subgraph(id, loadedModules) + return web.json_response(await self.sanitize_entry(subgraph)) diff --git a/server.py b/server.py index 10c2698b55ab..fe58db286fce 100644 --- a/server.py +++ b/server.py @@ -35,6 +35,7 @@ from app.user_manager import UserManager from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager +from app.subgraph_manager import SubgraphManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes from protocol import BinaryEventTypes @@ -173,6 +174,7 @@ def __init__(self, loop): self.user_manager = UserManager() self.model_file_manager = ModelFileManager() self.custom_node_manager = CustomNodeManager() + self.subgraph_manager = SubgraphManager() self.internal_routes = InternalRoutes(self) self.supports = ["custom_nodes_from_web"] self.prompt_queue = execution.PromptQueue(self) @@ -819,6 +821,7 @@ def add_routes(self): self.user_manager.add_routes(self.routes) self.model_file_manager.add_routes(self.routes) self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items()) + self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items()) self.app.add_subapp('/internal', self.internal_routes.get_app()) # Prefix every route with /api for easier matching for delegation. From ad6c14c37033f5921c37ae225d1fa8c5a5325e5a Mon Sep 17 00:00:00 2001 From: lspindler Date: Wed, 22 Oct 2025 10:30:00 +0200 Subject: [PATCH 21/49] Updated design using Tensor Subclasses --- comfy/model_detection.py | 4 +- comfy/ops.py | 514 ++++-------------- comfy/quant_ops.py | 346 ++++++++++++ tests-unit/comfy_test/test_mixed_precision.py | 274 ++++++++++ tests-unit/comfy_test/test_quant_detection.py | 262 +++++++++ tests-unit/comfy_test/test_quant_registry.py | 399 ++++++++++++++ 6 files changed, 1400 insertions(+), 399 deletions(-) create mode 100644 comfy/quant_ops.py create mode 100644 tests-unit/comfy_test/test_mixed_precision.py create mode 100644 tests-unit/comfy_test/test_quant_detection.py create mode 100644 tests-unit/comfy_test/test_quant_registry.py diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7ce9aaa9af4a..01f26836b8d4 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -107,10 +107,10 @@ def detect_layer_quantization(state_dict, prefix="model.diffusion_model."): if weight_key in state_dict: weight_dtype = state_dict[weight_key].dtype if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - format_name = "fp8_e4m3fn_scaled" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2_scaled" + format_name = "fp8_e4m3fn" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2" layer_configs[layer_name] = { "format": format_name, - "params": {"use_fp8_matmul": True} + "params": {} } elif key.endswith(".weight") and not key.endswith(".scale_weight"): layer_name = normalize_layer_name(key, known_prefixes) diff --git a/comfy/ops.py b/comfy/ops.py index 7ce7d3293451..2e6782dbd4f7 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -324,6 +324,21 @@ class Embedding(disable_weight_init.Embedding): def fp8_linear(self, input): + """ + Legacy FP8 linear function - now uses tensor subclass infrastructure. + + This function maintains backward compatibility with existing code while + routing all FP8 computation through the unified tensor subclass system. + All actual FP8 matmul logic is handled by the registered operation handlers + in quant_ops.py via __torch_dispatch__. + + Args: + self: Linear layer with FP8 weight and scale parameters + input: Input tensor (any dtype) + + Returns: + Output tensor or None if weight is not FP8 + """ dtype = self.weight.dtype if dtype not in [torch.float8_e4m3fn]: return None @@ -335,10 +350,12 @@ def fp8_linear(self, input): input_shape = input.shape input_dtype = input.dtype + if len(input.shape) == 3: + # Get weight and bias using standard casting w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) - w = w.t() + # Get scales (same as before) scale_weight = self.scale_weight scale_input = self.scale_input if scale_weight is None: @@ -348,23 +365,31 @@ def fp8_linear(self, input): if scale_input is None: scale_input = torch.ones((), device=input.device, dtype=torch.float32) - input = torch.clamp(input, min=-448, max=448, out=input) - input = input.reshape(-1, input_shape[2]).to(dtype).contiguous() else: scale_input = scale_input.to(input.device) - input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous() - - if bias is not None: - o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) + + # Wrap weight in QuantizedTensorFP8 - this enables unified dispatch + quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype) + + # Handle input quantization and wrapping + if self.scale_input is None: + # Clamp input to FP8 range and quantize + input = torch.clamp(input, min=-448, max=448, out=input) + input_fp8 = input.reshape(-1, input_shape[2]).to(dtype).contiguous() else: - o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight) - - if isinstance(o, tuple): - o = o[0] - + # Apply inverse scale and quantize + input_fp8 = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous() + + # Wrap input in QuantizedTensorFP8 + quantized_input = QuantizedTensorFP8(input_fp8, scale_input, orig_dtype=input_dtype) + + # Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py! + # This is the key unification: all FP8 computation goes through one path + o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) + + # Reshape output if tensor_2d: return o.reshape(input_shape[0], -1) - return o.reshape((-1, input_shape[1], self.weight.shape[0])) return None @@ -459,307 +484,8 @@ def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) -# ============================================================================== -# Quantization Format Registry System -# ============================================================================== - -class QuantFormatHandler: - """ - Base class for all quantization format handlers. - - A handler encapsulates the logic for a specific quantization format - (e.g., FP8 scaled, MX formats) and manages the quantization - parameters and forward pass for quantized layers. - """ - - def __init__(self, layer, **config): - """ - Initialize handler for a specific layer. - - Args: - layer: The nn.Module layer (Linear, Conv2d, etc.) - **config: Format-specific configuration - """ - self.layer = layer - self.config = config - - def setup_parameters(self): - """ - Initialize quantization parameters on the layer. - Called during layer construction or load_state_dict. - - Subclasses should create parameters like scale_weight, scale_input, etc. - and attach them to self.layer. - """ - raise NotImplementedError(f"{self.__class__.__name__} must implement setup_parameters()") - - def forward(self, *args, **kwargs): - """ - Execute quantized forward pass. - - Signature matches the layer's expected forward pass. - Handler accesses layer parameters via self.layer (weight, bias, etc.) - - Args: - *args: Positional arguments matching layer forward signature - **kwargs: Keyword arguments matching layer forward signature - - Returns: - Layer output tensor - - Examples: - Linear: forward(input) - Conv2d: forward(input) - GroupNorm: forward(input) - """ - raise NotImplementedError(f"{self.__class__.__name__} must implement forward()") - - def load_state_dict(self, state_dict, prefix): - """ - Load quantization parameters from state dict. - - Args: - state_dict: State dictionary - prefix: Key prefix for this layer (e.g., "model.diffusion_model.layer1.") - """ - raise NotImplementedError(f"{self.__class__.__name__} must implement load_state_dict()") - - def state_dict(self, prefix): - """ - Save quantization parameters to state dict. - - Args: - prefix: Key prefix for this layer - - Returns: - Dictionary of quantization parameters with full keys - """ - raise NotImplementedError(f"{self.__class__.__name__} must implement state_dict()") - - def convert_weight(self, weight, inplace=False): - """ - Convert weight from quantized to full precision (dequantize). - - Args: - weight: Quantized weight tensor - inplace: Whether to modify in-place - - Returns: - Dequantized weight tensor - """ - raise NotImplementedError(f"{self.__class__.__name__} must implement convert_weight()") - - def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False): - """ - Convert and set weight from full precision to quantized. - - Args: - weight: Full precision weight tensor - inplace_update: Whether to update layer weight in-place - seed: Random seed for stochastic rounding - return_weight: If True, return quantized weight without setting - - Returns: - Quantized weight if return_weight=True, else None - """ - raise NotImplementedError(f"{self.__class__.__name__} must implement set_weight()") - - -class QuantFormatRegistry: - """ - Global registry for quantization formats. - - Formats are registered with a unique name and handler class. - Custom formats can be registered by custom nodes. - """ - - _formats = {} - - @classmethod - def register(cls, name, handler_class, **default_config): - """ - Register a new quantization format. - - Args: - name: Unique format identifier (e.g., "fp8_e4m3fn_scaled") - handler_class: Handler class implementing QuantFormatHandler - **default_config: Default configuration parameters - - Example: - QuantFormatRegistry.register( - "fp8_e4m3fn_scaled", - handler_class=FP8ScaledHandler, - base_dtype=torch.float8_e4m3fn, - quantize_activation=False, - use_fp8_matmul=True, - ) - """ - if not issubclass(handler_class, QuantFormatHandler): - raise TypeError(f"handler_class must be a subclass of QuantFormatHandler, got {handler_class}") - - cls._formats[name] = { - "handler": handler_class, - "config": default_config.copy() - } - logging.debug(f"Registered quantization format: {name}") - - @classmethod - def get(cls, name, **override_config): - """ - Get format info with optional config overrides. - - Args: - name: Format identifier - **override_config: Configuration overrides - - Returns: - Dict with 'handler' (class) and 'config' (dict) keys - - Raises: - ValueError: If format name not registered - """ - if name not in cls._formats: - available = ", ".join(cls._formats.keys()) if cls._formats else "none" - raise ValueError(f"Unknown quantization format: '{name}'. Available formats: {available}") - - format_info = cls._formats[name].copy() - # Merge override_config into default config - config = format_info["config"].copy() - config.update(override_config) - format_info["config"] = config - return format_info - - @classmethod - def list_formats(cls): - """List all registered format names""" - return list(cls._formats.keys()) - - @classmethod - def is_registered(cls, name): - """Check if a format is registered""" - return name in cls._formats - - -class FP8ScaledHandler(QuantFormatHandler): - """ - Handler for FP8 quantization with per-tensor scaling. - - Supports both weight-only and weight+activation quantization. - Compatible with existing fp8_linear implementation. - """ - - def setup_parameters(self): - """Initialize scale_weight and optionally scale_input""" - device = self.layer.weight.device - dtype = torch.float32 - - # Always have scale_weight for FP8 - if not hasattr(self.layer, 'scale_weight') or self.layer.scale_weight is None: - self.layer.scale_weight = torch.nn.Parameter( - torch.ones((), device=device, dtype=dtype), - requires_grad=False - ) - - # scale_input is optional (for activation quantization) - if self.config.get("quantize_activation", False): - if not hasattr(self.layer, 'scale_input') or self.layer.scale_input is None: - self.layer.scale_input = torch.nn.Parameter( - torch.ones((), device=device, dtype=dtype), - requires_grad=False - ) - else: - self.layer.scale_input = None - - def forward(self, *args, **kwargs): - """ - FP8 forward pass with optional activation quantization. - Supports Linear layers (Conv2d in future). - """ - # Detect layer type and dispatch - if isinstance(self.layer, torch.nn.Linear): - return self._forward_linear(*args, **kwargs) - else: - raise NotImplementedError( - f"FP8ScaledHandler not implemented for {type(self.layer).__name__}" - ) - - def _forward_linear(self, input): - """FP8 forward for Linear layers""" - # Try fast path with fp8_linear if enabled - if self.config.get("use_fp8_matmul", False) and not self.layer.training: - try: - result = fp8_linear(self.layer, input) - if result is not None: - return result - except Exception as e: - logging.debug(f"FP8 matmul failed, falling back to standard path: {e}") - - # Standard path: dequantize and compute - weight, bias = cast_bias_weight(self.layer, input) - - # Dequantize weight - scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) - - # Apply weight functions (LoRA, etc.) - they see dequantized weights - if hasattr(self.layer, 'weight_function') and len(self.layer.weight_function) > 0: - weight = weight * scale - for f in self.layer.weight_function: - weight = f(weight) - else: - weight = weight * scale - - if hasattr(self.layer, 'bias_function') and len(self.layer.bias_function) > 0: - for f in self.layer.bias_function: - bias = f(bias) if bias is not None else None - - # Execute linear operation - # Optimization: multiply by scale on smaller tensor - if weight.numel() < input.numel() and len(self.layer.weight_function) == 0: - return torch.nn.functional.linear(input, weight, bias) - else: - return torch.nn.functional.linear(input, weight, bias) - - def load_state_dict(self, state_dict, prefix): - """Load scale parameters from state dict""" - scale_weight_key = f"{prefix}scale_weight" - if scale_weight_key in state_dict: - self.layer.scale_weight.data.copy_(state_dict[scale_weight_key]) - - scale_input_key = f"{prefix}scale_input" - if scale_input_key in state_dict and self.layer.scale_input is not None: - self.layer.scale_input.data.copy_(state_dict[scale_input_key]) - - def state_dict(self, prefix): - """Save scale parameters to state dict""" - result = {f"{prefix}scale_weight": self.layer.scale_weight} - if self.layer.scale_input is not None: - result[f"{prefix}scale_input"] = self.layer.scale_input - return result - - def convert_weight(self, weight, inplace=False): - """Dequantize: multiply by scale""" - scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) - if inplace: - weight *= scale - return weight - return weight * scale - - def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False): - """Quantize: divide by scale with stochastic rounding""" - scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) - quantized = comfy.float.stochastic_rounding( - weight / scale, - self.layer.weight.dtype, - seed=seed - ) - - if return_weight: - return quantized - - if inplace_update: - self.layer.weight.data.copy_(quantized) - else: - self.layer.weight = torch.nn.Parameter(quantized, requires_grad=False) +# Import quantization operations from separate module +from .quant_ops import QuantizedTensorFP8 # ============================================================================== @@ -780,12 +506,13 @@ class MixedPrecisionOps(disable_weight_init): _layer_quant_config = {} # Class variable set by pick_operations() class Linear(disable_weight_init.Linear): - """Linear layer with optional per-layer quantization""" + """Linear layer with optional per-layer quantization using tensor subclasses""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.quant_handler = None - self._handler_initialized = False + self.quant_format = None + self.quant_scale = None + self._quantization_initialized = False def reset_parameters(self): # Don't allocate weights - return None like disable_weight_init @@ -795,9 +522,16 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): """ Called by PyTorch during load_state_dict. - This is where we initialize the handler since we now know the layer name. + Load weight and wrap in QuantizedTensorFP8 if this layer is quantized. """ - if not self._handler_initialized: + # Call parent to load weight and bias first + super()._load_from_state_dict( + state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs + ) + + # After weight is loaded, wrap it if this layer is quantized + if not self._quantization_initialized: # Normalize layer name from prefix layer_name = prefix.rstrip('.') @@ -808,60 +542,78 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, break # Check if this layer has quantization config - # Access via parent class since _layer_quant_config is a class variable if layer_name in MixedPrecisionOps._layer_quant_config: config = MixedPrecisionOps._layer_quant_config[layer_name] - try: - format_info = QuantFormatRegistry.get( - config["format"], - **config.get("params", {}) - ) + self.quant_format = config.get("format", "fp8_e4m3fn") + + # Load scale parameter + scale_key = f"{prefix}scale_weight" + if scale_key in state_dict: + self.quant_scale = state_dict[scale_key] - # Initialize handler - self.quant_handler = format_info["handler"](self, **format_info["config"]) - self.quant_handler.setup_parameters() - - # Let handler load its parameters (scale_weight, etc.) - self.quant_handler.load_state_dict(state_dict, prefix) - - logging.debug(f"Initialized {config['format']} handler for layer {layer_name}") - except ValueError as e: - # Format not registered - fall back to standard precision - logging.warning( - f"Quantization format '{config['format']}' not registered for layer {layer_name}. " - f"Falling back to standard precision. Error: {e}" - ) - self.quant_handler = None - except Exception as e: - logging.error(f"Failed to initialize quantization handler for {layer_name}: {e}") - self.quant_handler = None + # Wrap weight in QuantizedTensorFP8 + if self.weight is not None and self.weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + try: + # Determine original dtype (default to bfloat16) + orig_dtype = torch.bfloat16 + + # Wrap weight in quantized tensor subclass + quantized_weight = QuantizedTensorFP8( + self.weight.data, + self.quant_scale, + orig_dtype=orig_dtype + ) + + # Replace weight parameter with wrapped version + self.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + + logging.debug(f"Wrapped layer {layer_name} weight in QuantizedTensorFP8 (format: {self.quant_format})") + except Exception as e: + logging.warning(f"Failed to wrap layer {layer_name} in QuantizedTensorFP8: {e}") + self.quant_format = None + self.quant_scale = None + else: + logging.debug(f"Layer {layer_name} has scale but weight dtype is not FP8, skipping quantization") + self.quant_format = None + self.quant_scale = None + else: + logging.debug(f"Layer {layer_name} has quant config but no scale_weight in state_dict") + self.quant_format = None - self._handler_initialized = True - - # Call parent to load weight and bias - super()._load_from_state_dict( - state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs - ) + self._quantization_initialized = True def _save_to_state_dict(self, destination, prefix, keep_vars): - """Save layer parameters including quantization metadata""" - super()._save_to_state_dict(destination, prefix, keep_vars) - - # Save handler parameters (scale_weight, etc.) - if self.quant_handler: - handler_dict = self.quant_handler.state_dict(prefix) - destination.update(handler_dict) + """Save layer parameters including quantization scale""" + # First unwrap the weight if it's quantized + if isinstance(self.weight, torch.nn.Parameter) and isinstance(self.weight.data, QuantizedTensorFP8): + # Temporarily unwrap to save the raw FP8 data + quantized_tensor = self.weight.data + raw_fp8_data = quantized_tensor._raw_data + original_weight = self.weight + self.weight = torch.nn.Parameter(raw_fp8_data, requires_grad=False) + + # Call parent to save unwrapped weight + super()._save_to_state_dict(destination, prefix, keep_vars) + + # Restore the wrapped weight + self.weight = original_weight + + # Save the scale parameter + if self.quant_scale is not None: + destination[f"{prefix}scale_weight"] = self.quant_scale if keep_vars else self.quant_scale.detach() + else: + # Standard path for non-quantized weights + super()._save_to_state_dict(destination, prefix, keep_vars) def forward_comfy_cast_weights(self, input): - """Forward pass with optional quantization""" - if self.quant_handler: - # Use handler for quantized forward - return self.quant_handler.forward(input) - else: - # Standard path for non-quantized layers - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight, bias) + """ + Forward pass - tensor subclass handles dispatch automatically! + __torch_dispatch__ will route to registered handlers based on tensor types. + """ + weight, bias = cast_bias_weight(self, input) + + # Call F.linear - if weight is QuantizedTensorFP8, __torch_dispatch__ handles it! + return torch.nn.functional.linear(input, weight, bias) def forward(self, *args, **kwargs): """Main forward pass""" @@ -933,35 +685,3 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ return disable_weight_init return manual_cast - - -# ============================================================================== -# Register built-in quantization formats -# ============================================================================== - -# FP8 E4M3FN weight-only quantization -QuantFormatRegistry.register( - "fp8_e4m3fn_scaled", - handler_class=FP8ScaledHandler, - base_dtype=torch.float8_e4m3fn, - quantize_activation=False, - use_fp8_matmul=True, -) - -# FP8 E4M3FN weight+activation quantization -QuantFormatRegistry.register( - "fp8_e4m3fn_scaled_dynamic", - handler_class=FP8ScaledHandler, - base_dtype=torch.float8_e4m3fn, - quantize_activation=True, - use_fp8_matmul=True, -) - -# FP8 E5M2 weight-only quantization -QuantFormatRegistry.register( - "fp8_e5m2_scaled", - handler_class=FP8ScaledHandler, - base_dtype=torch.float8_e5m2, - quantize_activation=False, - use_fp8_matmul=True, -) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py new file mode 100644 index 000000000000..681eb9134935 --- /dev/null +++ b/comfy/quant_ops.py @@ -0,0 +1,346 @@ +import torch +import logging + +# ============================================================================== +# Global Operation Registry +# ============================================================================== + +# Global operation registry: torch operation → handler function +_QUANT_OP_REGISTRY = {} + +def register_quant_op(torch_op): + """ + Decorator to register an operation handler. + + Example: + @register_quant_op(torch.ops.aten.linear.default) + def handle_linear_fp8(func, args, kwargs): + # Implementation + ... + """ + def decorator(handler_func): + _QUANT_OP_REGISTRY[torch_op] = handler_func + return handler_func + return decorator + + +def get_quant_handler(torch_op): + """Get registered handler for an operation""" + return _QUANT_OP_REGISTRY.get(torch_op) + + +def list_registered_ops(): + """List all registered quantized operations""" + return list(_QUANT_OP_REGISTRY.keys()) + + +# ============================================================================== +# comfy_kitchen Integration +# ============================================================================== + +try: + import comfy_kitchen as ck + ck.disable_backend("cutile") + _CK_AVAILABLE = True + logging.info("comfy_kitchen available for optimized quantization kernels") +except ImportError: + ck = None + _CK_AVAILABLE = False + logging.info("comfy_kitchen not available - using PyTorch fallbacks") +except Exception as e: + ck = None + _CK_AVAILABLE = False + logging.warning(f"comfy_kitchen import failed: {e} - using PyTorch fallbacks") + + +# ============================================================================== +# Quantized Tensor Subclass +# ============================================================================== + +class QuantizedTensorFP8(torch.Tensor): + """ + Tensor subclass for FP8 quantized data. + Automatically handles operations via __torch_dispatch__. + """ + + @staticmethod + def __new__(cls, tensor, scale, orig_dtype=torch.bfloat16): + """ + Create a quantized FP8 tensor. + + Args: + tensor: The FP8 tensor data (torch.float8_e4m3fn or e5m2) + scale: Scale factor for dequantization (scalar tensor) + orig_dtype: Original dtype before quantization + """ + return torch.Tensor._make_subclass(cls, tensor, require_grad=False) + + def __init__(self, tensor, scale, orig_dtype=torch.bfloat16): + self._scale = scale + self._orig_dtype = orig_dtype + # Store a reference to prevent infinite recursion in dequantize + self._raw_data = tensor + + def __repr__(self): + return (f"QuantizedTensorFP8(shape={self.shape}, " + f"scale={self._scale:.4f}, dtype={self._orig_dtype})") + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + """ + Intercept ALL torch operations. + Routes to registered handlers or falls back to dequantization. + """ + kwargs = kwargs or {} + + # Special case: skip dispatch for internal tensor operations + # that are used for unwrapping (to avoid recursion) + if func in [torch.ops.aten._to_copy.default, torch.ops.aten.detach.default]: + # For these ops, use the raw data to avoid recursion, but return QuantizedTensorFP8 for detach + if func == torch.ops.aten.detach.default and isinstance(args[0], QuantizedTensorFP8): + # Special handling for detach - return a new QuantizedTensorFP8 + qt = args[0] + detached_data = qt._raw_data.detach() + return QuantizedTensorFP8(detached_data, qt._scale, qt._orig_dtype) + + # For other ops, just unwrap + def unwrap(arg): + if isinstance(arg, QuantizedTensorFP8): + return arg._raw_data + return arg + new_args = tuple(unwrap(a) if not isinstance(a, (list, tuple, dict)) else a for a in args) + return func(*new_args, **kwargs) + + # Look up registered handler for this operation + handler = _QUANT_OP_REGISTRY.get(func) + if handler: + return handler(func, args, kwargs) + + # No handler - dequantize and use standard path + return cls._dequant_and_fallback(func, args, kwargs) + + @classmethod + def _dequant_and_fallback(cls, func, args, kwargs): + """Fallback: dequantize all quantized tensors""" + def dequant_arg(arg): + if isinstance(arg, QuantizedTensorFP8): + return arg.dequantize() + elif isinstance(arg, (list, tuple)): + return type(arg)(dequant_arg(a) for a in arg) + return arg + + new_args = dequant_arg(args) + new_kwargs = dequant_arg(kwargs) + return func(*new_args, **new_kwargs) + + def dequantize(self) -> torch.Tensor: + """Explicit dequantization""" + # Use the raw data and convert directly + # Call aten ops directly to minimize dispatch interference + plain_tensor = torch.ops.aten._to_copy.default(self._raw_data, dtype=self._orig_dtype) + # Multiply by scale + return plain_tensor * self._scale + + def detach(self): + """Detach returns a new QuantizedTensorFP8 (required for Parameter)""" + # Detach the raw data and create a new QuantizedTensorFP8 + detached_data = self._raw_data.detach() + return QuantizedTensorFP8(detached_data, self._scale, self._orig_dtype) + + +# ============================================================================== +# Operation Handlers for Quantized Tensors +# ============================================================================== + +@register_quant_op(torch.ops.aten.linear.default) +def handle_linear_fp8(func, args, kwargs): + """ + Handle F.linear() with quantized inputs. + + Supports: + - QuantizedTensorFP8 input + QuantizedTensorFP8 weight + - QuantizedTensorFP8 input + regular weight + - Regular input + QuantizedTensorFP8 weight + """ + input_tensor = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + + # Case 1: Both input and weight are FP8 + if isinstance(input_tensor, QuantizedTensorFP8) and isinstance(weight, QuantizedTensorFP8): + # Use _scaled_mm for FP8×FP8 matmul + # Get plain tensors to avoid dispatch recursion + plain_input = input_tensor._raw_data + plain_weight = weight._raw_data + weight_t = plain_weight.t().contiguous() + + try: + if bias is not None: + output = torch._scaled_mm( + plain_input, + weight_t, + out_dtype=input_tensor._orig_dtype, + bias=bias, + scale_a=input_tensor._scale, + scale_b=weight._scale + ) + else: + output = torch._scaled_mm( + plain_input, + weight_t, + out_dtype=input_tensor._orig_dtype, + scale_a=input_tensor._scale, + scale_b=weight._scale + ) + + if isinstance(output, tuple): + output = output[0] + + # Check if output is FP8 (some architectures support this) + if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + # Keep quantized! + output_scale = input_tensor._scale * weight._scale + return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype) + else: + return output + except Exception as e: + logging.debug(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") + # Fall through to dequantization path + + # Case 2: Only weight is quantized + if isinstance(weight, QuantizedTensorFP8): + weight_dq = weight.dequantize() + input_dq = input_tensor.dequantize() if isinstance(input_tensor, QuantizedTensorFP8) else input_tensor + return torch.nn.functional.linear(input_dq, weight_dq, bias) + + # Case 3: Only input is quantized + elif isinstance(input_tensor, QuantizedTensorFP8): + input_dq = input_tensor.dequantize() + return torch.nn.functional.linear(input_dq, weight, bias) + + # Case 4: Neither is quantized (shouldn't happen, but handle it) + else: + return torch.nn.functional.linear(input_tensor, weight, bias) + + +@register_quant_op(torch.ops.aten.silu.default) +def handle_silu_fp8(func, args, kwargs): + """ + SiLU can be computed approximately on FP8. + Keeps activations quantized for next layer. + """ + input_q = args[0] + + if not isinstance(input_q, QuantizedTensorFP8): + # Not quantized, use standard path + return torch.nn.functional.silu(input_q) + + # Compute SiLU while keeping quantized + # SiLU(x) = x * sigmoid(x) + + # Get plain tensor to avoid dispatch recursion + plain_tensor = input_q._raw_data + + # Upcast to FP16 for sigmoid stability + x_fp16 = plain_tensor.to(torch.float16) + sigmoid_fp16 = torch.sigmoid(x_fp16 * input_q._scale) + result_fp16 = x_fp16 * sigmoid_fp16 + + # Convert back to FP8 + result_fp8 = result_fp16.to(plain_tensor.dtype) + + # Return quantized (scale approximately preserved) + return QuantizedTensorFP8(result_fp8, input_q._scale, input_q._orig_dtype) + + +@register_quant_op(torch.ops.aten.layer_norm.default) +def handle_layernorm_fp8(func, args, kwargs): + """ + LayerNorm requires high precision. + Dequantizes input and returns standard tensor. + """ + input_q = args[0] + normalized_shape = args[1] + weight = args[2] if len(args) > 2 else None + bias = args[3] if len(args) > 3 else None + eps = args[4] if len(args) > 4 else 1e-5 + + # Dequantize if needed + if isinstance(input_q, QuantizedTensorFP8): + x = input_q.dequantize() + else: + x = input_q + + # Standard LayerNorm + result = torch.nn.functional.layer_norm(x, normalized_shape, weight, bias, eps) + + # Return dequantized (next layer will quantize if needed) + return result + + +@register_quant_op(torch.ops.aten.group_norm.default) +def handle_groupnorm_fp8(func, args, kwargs): + """ + GroupNorm requires high precision. + Dequantizes input and returns standard tensor. + """ + input_q = args[0] + num_groups = args[1] + weight = args[2] if len(args) > 2 else None + bias = args[3] if len(args) > 3 else None + eps = args[4] if len(args) > 4 else 1e-5 + + # Dequantize if needed + if isinstance(input_q, QuantizedTensorFP8): + x = input_q.dequantize() + else: + x = input_q + + # Standard GroupNorm + result = torch.nn.functional.group_norm(x, num_groups, weight, bias, eps) + + # Return dequantized + return result + + +@register_quant_op(torch.ops.aten.add.Tensor) +def handle_add_fp8(func, args, kwargs): + """ + Handle addition with mixed quantized/non-quantized tensors. + """ + a = args[0] + b = args[1] + + # If both are quantized, dequantize both + if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8): + return a.dequantize() + b.dequantize() + # If only one is quantized, dequantize it + elif isinstance(a, QuantizedTensorFP8): + return a.dequantize() + b + elif isinstance(b, QuantizedTensorFP8): + return a + b.dequantize() + # Neither is quantized + else: + return a + b + + +@register_quant_op(torch.ops.aten.mul.Tensor) +def handle_mul_fp8(func, args, kwargs): + """ + Handle multiplication with mixed quantized/non-quantized tensors. + """ + a = args[0] + b = args[1] + + # If both are quantized, dequantize both + if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8): + return a.dequantize() * b.dequantize() + # If only one is quantized, dequantize it + elif isinstance(a, QuantizedTensorFP8): + return a.dequantize() * b + elif isinstance(b, QuantizedTensorFP8): + return a * b.dequantize() + # Neither is quantized + else: + return a * b + diff --git a/tests-unit/comfy_test/test_mixed_precision.py b/tests-unit/comfy_test/test_mixed_precision.py new file mode 100644 index 000000000000..cbfa2866da4d --- /dev/null +++ b/tests-unit/comfy_test/test_mixed_precision.py @@ -0,0 +1,274 @@ +""" +End-to-end tests for mixed precision quantization. +Tests Phase 3: Mixed Precision Operations +""" + +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from comfy import ops + + +class SimpleModel(torch.nn.Module): + """Simple model for testing mixed precision""" + def __init__(self, operations=ops.disable_weight_init): + super().__init__() + self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16) + self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16) + self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16) + + def forward(self, x): + x = self.layer1(x) + x = torch.nn.functional.relu(x) + x = self.layer2(x) + x = torch.nn.functional.relu(x) + x = self.layer3(x) + return x + + +class TestMixedPrecisionOps(unittest.TestCase): + """Test MixedPrecisionOps end-to-end""" + + def test_all_layers_standard(self): + """Test that model with no quantization works normally""" + # Configure no quantization + ops.MixedPrecisionOps._layer_quant_config = {} + + # Create model + model = SimpleModel(operations=ops.MixedPrecisionOps) + + # Initialize weights manually + model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16)) + model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16)) + model.layer2.weight = torch.nn.Parameter(torch.randn(30, 20, dtype=torch.bfloat16)) + model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16)) + model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16)) + model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16)) + + # Initialize weight_function and bias_function + for layer in [model.layer1, model.layer2, model.layer3]: + layer.weight_function = [] + layer.bias_function = [] + + # Forward pass + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + self.assertEqual(output.dtype, torch.bfloat16) + + def test_mixed_precision_load(self): + """Test loading a mixed precision model from state dict""" + # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard + layer_quant_config = { + "layer1": { + "format": "fp8_e4m3fn_scaled", + "params": {"use_fp8_matmul": False} # Disable for CPU testing + }, + "layer3": { + "format": "fp8_e5m2_scaled", + "params": {"use_fp8_matmul": False} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create state dict with mixed precision + fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e5m2) + + state_dict = { + # Layer 1: FP8 E4M3FN + "layer1.weight": fp8_weight1, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32), + + # Layer 2: Standard BF16 + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + + # Layer 3: FP8 E5M2 + "layer3.weight": fp8_weight3, + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + "layer3.scale_weight": torch.tensor(1.5, dtype=torch.float32), + } + + # Create model and load state dict + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict) + + # Verify handlers are set up correctly + self.assertIsNotNone(model.layer1.quant_handler) + self.assertIsNone(model.layer2.quant_handler) # No quantization + self.assertIsNotNone(model.layer3.quant_handler) + + # Verify scales were loaded + self.assertEqual(model.layer1.scale_weight.item(), 2.0) + self.assertEqual(model.layer3.scale_weight.item(), 1.5) + + # Forward pass + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + + def test_state_dict_round_trip(self): + """Test saving and loading state dict preserves quantization""" + # Configure mixed precision + layer_quant_config = { + "layer1": { + "format": "fp8_e4m3fn_scaled", + "params": {"use_fp8_matmul": False} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create and load model + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict1 = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.scale_weight": torch.tensor(3.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + model1 = SimpleModel(operations=ops.MixedPrecisionOps) + model1.load_state_dict(state_dict1) + + # Save state dict + state_dict2 = model1.state_dict() + + # Verify scale_weight is saved + self.assertIn("layer1.scale_weight", state_dict2) + self.assertEqual(state_dict2["layer1.scale_weight"].item(), 3.0) + + # Load into new model + model2 = SimpleModel(operations=ops.MixedPrecisionOps) + model2.load_state_dict(state_dict2) + + # Verify handler is set up + self.assertIsNotNone(model2.layer1.quant_handler) + self.assertEqual(model2.layer1.scale_weight.item(), 3.0) + + # Verify forward passes match + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output1 = model1(input_tensor) + output2 = model2(input_tensor) + + torch.testing.assert_close(output1, output2, rtol=1e-3, atol=1e-3) + + def test_weight_function_compatibility(self): + """Test that weight_function (LoRA) works with quantized layers""" + # Configure FP8 quantization + layer_quant_config = { + "layer1": { + "format": "fp8_e4m3fn_scaled", + "params": {"use_fp8_matmul": False} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create and load model + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict) + + # Add a weight function (simulating LoRA) + # LoRA delta must match weight shape (20, 10) + def apply_lora(weight): + # Generate LoRA delta matching weight shape + lora_delta = torch.randn_like(weight) * 0.01 + return weight + lora_delta + + model.layer1.weight_function.append(apply_lora) + + # Forward pass should work with LoRA + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + + def test_error_handling_unknown_format(self): + """Test that unknown formats fall back gracefully""" + # Configure with unknown format + layer_quant_config = { + "layer1": { + "format": "unknown_format_xyz", + "params": {} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create state dict + state_dict = { + "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16), + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + # Load should not crash, just log warning + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict) + + # Handler should be None (fallback to standard) + self.assertIsNone(model.layer1.quant_handler) + + # Forward pass should still work + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + self.assertEqual(output.shape, (5, 40)) + + +class TestPickOperationsWithMixedPrecision(unittest.TestCase): + """Test pick_operations with mixed precision config""" + + def test_pick_operations_with_layer_quant_config(self): + """Test that pick_operations returns MixedPrecisionOps when config present""" + from comfy import supported_models_base + + # Create model config with layer_quant_config + model_config = supported_models_base.BASE({}) + model_config.layer_quant_config = { + "layer1": {"format": "fp8_e4m3fn_scaled", "params": {}} + } + + result = ops.pick_operations(None, None, model_config=model_config) + + self.assertEqual(result, ops.MixedPrecisionOps) + self.assertEqual(ops.MixedPrecisionOps._layer_quant_config, model_config.layer_quant_config) + + def test_pick_operations_without_layer_quant_config(self): + """Test that pick_operations falls back to standard when no config""" + from comfy import supported_models_base + + model_config = supported_models_base.BASE({}) + model_config.layer_quant_config = None + + result = ops.pick_operations(None, None, model_config=model_config) + + self.assertEqual(result, ops.disable_weight_init) + + +if __name__ == "__main__": + unittest.main() + diff --git a/tests-unit/comfy_test/test_quant_detection.py b/tests-unit/comfy_test/test_quant_detection.py new file mode 100644 index 000000000000..bb952a81b3b2 --- /dev/null +++ b/tests-unit/comfy_test/test_quant_detection.py @@ -0,0 +1,262 @@ +""" +Integration tests for quantization detection. +Tests Phase 2: Detection & Integration +""" + +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from comfy import model_detection + + +class TestNormalizeLayerName(unittest.TestCase): + """Test the normalize_layer_name helper function""" + + def test_strip_prefix_and_suffix(self): + """Test stripping prefix and suffix""" + known_prefixes = ["model.diffusion_model."] + result = model_detection.normalize_layer_name( + "model.diffusion_model.layer1.weight", + known_prefixes + ) + self.assertEqual(result, "layer1") + + def test_strip_multiple_prefixes(self): + """Test with multiple known prefixes""" + known_prefixes = ["model.diffusion_model.", "model.model.", "net."] + + result1 = model_detection.normalize_layer_name( + "model.diffusion_model.block.attn.weight", + known_prefixes + ) + self.assertEqual(result1, "block.attn") + + result2 = model_detection.normalize_layer_name( + "model.model.encoder.layer.weight", + known_prefixes + ) + self.assertEqual(result2, "encoder.layer") + + result3 = model_detection.normalize_layer_name( + "net.transformer.blocks.0.weight", + known_prefixes + ) + self.assertEqual(result3, "transformer.blocks.0") + + def test_strip_scale_weight_suffix(self): + """Test stripping scale_weight suffix""" + known_prefixes = ["model.diffusion_model."] + result = model_detection.normalize_layer_name( + "model.diffusion_model.layer1.scale_weight", + known_prefixes + ) + self.assertEqual(result, "layer1") + + def test_strip_bias_suffix(self): + """Test stripping bias suffix""" + known_prefixes = ["model.diffusion_model."] + result = model_detection.normalize_layer_name( + "model.diffusion_model.layer1.bias", + known_prefixes + ) + self.assertEqual(result, "layer1") + + def test_no_prefix_match(self): + """Test with no prefix match""" + known_prefixes = ["model.diffusion_model."] + result = model_detection.normalize_layer_name( + "other.model.layer1.weight", + known_prefixes + ) + # Should strip suffix but not prefix + self.assertEqual(result, "other.model.layer1") + + +class TestDetectLayerQuantization(unittest.TestCase): + """Test the detect_layer_quantization function""" + + def test_no_quantization(self): + """Test with no quantization markers""" + state_dict = { + "model.diffusion_model.layer1.weight": torch.randn(10, 20), + "model.diffusion_model.layer2.weight": torch.randn(20, 30), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + self.assertIsNone(result) + + def test_legacy_scaled_fp8(self): + """Test that legacy scaled_fp8 marker returns None""" + # Create FP8 tensor by converting from float32 + fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "model.diffusion_model.scaled_fp8": torch.tensor([], dtype=torch.float8_e4m3fn), + "model.diffusion_model.layer1.weight": fp8_weight, + "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + # Should return None to trigger legacy path + self.assertIsNone(result) + + def test_metadata_format(self): + """Test with new metadata format""" + metadata = { + "format_version": "1.0", + "layers": { + "layer1": { + "format": "fp8_e4m3fn_scaled", + "params": {"use_fp8_matmul": True} + }, + "layer2": { + "format": "fp8_e5m2_scaled", + "params": {"use_fp8_matmul": True} + } + } + } + state_dict = { + "model.diffusion_model._quantization_metadata": metadata, + "model.diffusion_model.layer1.weight": torch.randn(10, 20), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + + self.assertIsNotNone(result) + self.assertIn("layer1", result) + self.assertIn("layer2", result) + self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled") + self.assertEqual(result["layer2"]["format"], "fp8_e5m2_scaled") + # Metadata should be popped from state_dict + self.assertNotIn("model.diffusion_model._quantization_metadata", state_dict) + + def test_mixed_precision_detection(self): + """Test detection of mixed precision via scale patterns""" + # Create FP8 tensors by converting from float32 + fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + fp8_weight3 = torch.randn(30, 40, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + # Layer 1: FP8 (has scale_weight) + "model.diffusion_model.layer1.weight": fp8_weight1, + "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), + # Layer 2: Standard (no scale_weight) + "model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), + # Layer 3: FP8 (has scale_weight) + "model.diffusion_model.layer3.weight": fp8_weight3, + "model.diffusion_model.layer3.scale_weight": torch.tensor(1.0), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + + self.assertIsNotNone(result) + self.assertIn("layer1", result) + self.assertIn("layer3", result) + self.assertNotIn("layer2", result) # Layer 2 not quantized + self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled") + self.assertEqual(result["layer3"]["format"], "fp8_e4m3fn_scaled") + + def test_all_layers_quantized(self): + """Test that uniform quantization (all layers) returns None""" + # Create FP8 tensors by converting from float32 + fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + fp8_weight2 = torch.randn(20, 30, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + # All layers have scale_weight + "model.diffusion_model.layer1.weight": fp8_weight1, + "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), + "model.diffusion_model.layer2.weight": fp8_weight2, + "model.diffusion_model.layer2.scale_weight": torch.tensor(1.0), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + + # If all layers are quantized, it's not mixed precision + # Should return None to use legacy scaled_fp8_ops path + self.assertIsNone(result) + + def test_fp8_e5m2_detection(self): + """Test detection of FP8 E5M2 format""" + # Create FP8 E5M2 tensor by converting from float32 + fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e5m2) + state_dict = { + "model.diffusion_model.layer1.weight": fp8_weight, + "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), + "model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + + self.assertIsNotNone(result) + self.assertIn("layer1", result) + self.assertEqual(result["layer1"]["format"], "fp8_e5m2_scaled") + + def test_invalid_metadata(self): + """Test with invalid metadata format""" + state_dict = { + "model.diffusion_model._quantization_metadata": "invalid_string", + "model.diffusion_model.layer1.weight": torch.randn(10, 20), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + # Should return None on invalid metadata + self.assertIsNone(result) + + def test_different_prefix(self): + """Test with different model prefix (audio model)""" + # Create FP8 tensor by converting from float32 + fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "model.model.layer1.weight": fp8_weight, + "model.model.layer1.scale_weight": torch.tensor(1.0), + "model.model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), + } + result = model_detection.detect_layer_quantization(state_dict, "model.model.") + + self.assertIsNotNone(result) + self.assertIn("layer1", result) + + +class TestPickOperationsIntegration(unittest.TestCase): + """Test pick_operations with model_config parameter""" + + def test_backward_compatibility(self): + """Test that pick_operations works without model_config (legacy)""" + from comfy import ops + + # Should work without model_config parameter + result = ops.pick_operations(None, None) + self.assertIsNotNone(result) + self.assertEqual(result, ops.disable_weight_init) + + def test_with_model_config_no_quant(self): + """Test with model_config but no quantization""" + from comfy import ops, supported_models_base + + model_config = supported_models_base.BASE({}) + model_config.layer_quant_config = None + + result = ops.pick_operations(None, None, model_config=model_config) + self.assertIsNotNone(result) + # Should use standard path + self.assertEqual(result, ops.disable_weight_init) + + def test_legacy_scaled_fp8(self): + """Test that legacy scaled_fp8 still works""" + from comfy import ops, supported_models_base + + model_config = supported_models_base.BASE({}) + model_config.scaled_fp8 = torch.float8_e4m3fn + + result = ops.pick_operations( + None, None, + scaled_fp8=torch.float8_e4m3fn, + model_config=model_config + ) + self.assertIsNotNone(result) + # Should return scaled_fp8_ops (the returned class is the inner class) + # Check that it's not the standard disable_weight_init + self.assertNotEqual(result, ops.disable_weight_init) + # Verify it has Linear class + self.assertTrue(hasattr(result, 'Linear')) + + +if __name__ == "__main__": + unittest.main() + diff --git a/tests-unit/comfy_test/test_quant_registry.py b/tests-unit/comfy_test/test_quant_registry.py new file mode 100644 index 000000000000..5c624b1db9d8 --- /dev/null +++ b/tests-unit/comfy_test/test_quant_registry.py @@ -0,0 +1,399 @@ +""" +Unit tests for tensor subclass quantization system. +Tests the new QuantizedTensorFP8 subclass and operation handlers. +""" + +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from comfy import ops +from comfy import quant_ops + + +class TestQuantizedTensorFP8(unittest.TestCase): + """Test the QuantizedTensorFP8 tensor subclass""" + + def test_creation(self): + """Test creating a QuantizedTensorFP8""" + fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(2.0) + + qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16) + + self.assertIsInstance(qt, quant_ops.QuantizedTensorFP8) + self.assertEqual(qt.shape, (256, 128)) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt._scale, scale) + self.assertEqual(qt._orig_dtype, torch.bfloat16) + + def test_dequantize(self): + """Test explicit dequantization""" + # Create a simple FP8 tensor + fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(3.0) + + qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.float32) + dequantized = qt.dequantize() + + # Dequantized should be approximately ones * 3.0 + self.assertEqual(dequantized.dtype, torch.float32) + self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) + + def test_repr(self): + """Test string representation""" + fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(2.5) + + qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16) + repr_str = repr(qt) + + self.assertIn("QuantizedTensorFP8", repr_str) + self.assertIn("shape", repr_str) + self.assertIn("scale", repr_str) + + +class TestOperationRegistry(unittest.TestCase): + """Test the operation registry system""" + + def test_registry_basics(self): + """Test that operations are registered""" + registered_ops = quant_ops.list_registered_ops() + + # Check that key operations are registered + self.assertIn(torch.ops.aten.linear.default, registered_ops) + self.assertIn(torch.ops.aten.silu.default, registered_ops) + self.assertIn(torch.ops.aten.layer_norm.default, registered_ops) + self.assertIn(torch.ops.aten.add.Tensor, registered_ops) + self.assertIn(torch.ops.aten.mul.Tensor, registered_ops) + + def test_get_handler(self): + """Test getting a registered handler""" + handler = quant_ops.get_quant_handler(torch.ops.aten.linear.default) + self.assertIsNotNone(handler) + self.assertTrue(callable(handler)) + + def test_custom_registration(self): + """Test registering a custom operation""" + + # Define a custom handler + @quant_ops.register_quant_op(torch.ops.aten.relu.default) + def custom_relu_handler(func, args, kwargs): + return func(*args, **kwargs) + + # Verify registration + handler = quant_ops.get_quant_handler(torch.ops.aten.relu.default) + self.assertIsNotNone(handler) + self.assertEqual(handler, custom_relu_handler) + + +class TestLinearHandler(unittest.TestCase): + """Test the linear operation handler""" + + def test_linear_with_quantized_weight(self): + """Test F.linear with quantized weight""" + # Set seed for reproducibility + torch.manual_seed(42) + + # Create quantized weight + weight_fp32 = torch.randn(256, 128, dtype=torch.float32) + scale = torch.tensor(2.0) + weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) + weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32) + + # Create input + input_tensor = torch.randn(16, 128, dtype=torch.float32) + + # Call linear (should trigger dispatch) + output = torch.nn.functional.linear(input_tensor, weight_q, bias=None) + + # Verify output shape + self.assertEqual(output.shape, (16, 256)) + + # Verify it's approximately correct (allowing for FP8 quantization error) + # Note: FP8 has limited precision, so use very loose tolerance + expected = torch.nn.functional.linear(input_tensor, weight_fp32, bias=None) + # Just check that it's in the right ballpark (within 50% error on average) + mean_rel_error = ((output - expected).abs() / (expected.abs() + 1e-6)).mean() + self.assertLess(mean_rel_error, 0.5, f"Mean relative error {mean_rel_error:.3f} is too large") + + def test_linear_with_bias(self): + """Test F.linear with quantized weight and bias""" + weight_fp32 = torch.randn(64, 32, dtype=torch.float32) + scale = torch.tensor(1.5) + weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) + weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32) + + input_tensor = torch.randn(8, 32, dtype=torch.float32) + bias = torch.randn(64, dtype=torch.float32) + + output = torch.nn.functional.linear(input_tensor, weight_q, bias) + + self.assertEqual(output.shape, (8, 64)) + + +class TestActivationHandlers(unittest.TestCase): + """Test activation function handlers""" + + def test_silu_with_quantized_input(self): + """Test SiLU with quantized input""" + # Create quantized input + input_fp32 = torch.randn(16, 128, dtype=torch.float32) + scale = torch.tensor(1.0) + input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn) + input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32) + + # Apply SiLU + output = torch.nn.functional.silu(input_q) + + # Should return a QuantizedTensorFP8 + self.assertIsInstance(output, quant_ops.QuantizedTensorFP8) + + # Verify approximate correctness + expected = torch.nn.functional.silu(input_fp32) + output_dq = output.dequantize() + self.assertTrue(torch.allclose(output_dq, expected, rtol=0.2, atol=0.2)) + + def test_layernorm_dequantizes(self): + """Test that LayerNorm dequantizes input""" + # Create quantized input + input_fp32 = torch.randn(16, 128, dtype=torch.float32) + scale = torch.tensor(1.0) + input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn) + input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32) + + # Apply LayerNorm + weight = torch.ones(128) + bias = torch.zeros(128) + output = torch.nn.functional.layer_norm(input_q, (128,), weight, bias) + + # Should NOT be quantized (LayerNorm breaks quantization) + self.assertNotIsInstance(output, quant_ops.QuantizedTensorFP8) + self.assertEqual(output.dtype, torch.float32) + + +class TestElementwiseHandlers(unittest.TestCase): + """Test element-wise operation handlers""" + + def test_add_mixed_tensors(self): + """Test addition with mixed quantized/non-quantized tensors""" + # Create quantized tensor + a_fp32 = torch.ones(10, 20, dtype=torch.float32) + scale = torch.tensor(1.0) + a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn) + a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32) + + # Non-quantized tensor + b = torch.ones(10, 20, dtype=torch.float32) * 2.0 + + # Add them + result = a_q + b + + # Should be dequantized + self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) + self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 3.0, rtol=0.1)) + + def test_mul_quantized_tensors(self): + """Test multiplication of two quantized tensors""" + a_fp32 = torch.ones(10, 20) * 2.0 + scale_a = torch.tensor(1.0) + a_fp8 = (a_fp32 / scale_a).to(torch.float8_e4m3fn) + a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale_a, orig_dtype=torch.float32) + + b_fp32 = torch.ones(10, 20) * 3.0 + scale_b = torch.tensor(1.0) + b_fp8 = (b_fp32 / scale_b).to(torch.float8_e4m3fn) + b_q = quant_ops.QuantizedTensorFP8(b_fp8, scale_b, orig_dtype=torch.float32) + + result = a_q * b_q + + # Should be dequantized + self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) + self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 6.0, rtol=0.1)) + + +class TestFallbackMechanism(unittest.TestCase): + """Test fallback for unsupported operations""" + + def test_unsupported_op_dequantizes(self): + """Test that unsupported operations fall back to dequantization""" + # Set seed for reproducibility + torch.manual_seed(42) + + # Create quantized tensor + a_fp32 = torch.randn(10, 20, dtype=torch.float32) + scale = torch.tensor(1.0) + a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn) + a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32) + + # Call an operation that doesn't have a registered handler + # For example, torch.abs + result = torch.abs(a_q) + + # Should work via fallback (dequantize → abs → return) + self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) + expected = torch.abs(a_fp32) + # FP8 introduces quantization error, so use loose tolerance + mean_error = (result - expected).abs().mean() + self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") + + +class TestMixedPrecisionOps(unittest.TestCase): + """Test MixedPrecisionOps integration""" + + def test_linear_layer_creation(self): + """Test that MixedPrecisionOps.Linear can be created""" + layer = ops.MixedPrecisionOps.Linear(128, 256, bias=True, device="cpu", dtype=torch.float32) + + self.assertIsInstance(layer, ops.MixedPrecisionOps.Linear) + self.assertFalse(layer._quantization_initialized) + self.assertIsNone(layer.quant_format) + + def test_layer_quant_config_detection(self): + """Test that layer quantization config is detected during load""" + # Set up layer config + ops.MixedPrecisionOps._layer_quant_config = { + "test_layer": { + "format": "fp8_e4m3fn", + "params": {} + } + } + + # Create a state dict with quantized weight + weight_fp32 = torch.randn(256, 128, dtype=torch.float32) + scale = torch.tensor(2.0) + weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) + + state_dict = { + "model.diffusion_model.test_layer.weight": weight_fp8, + "model.diffusion_model.test_layer.scale_weight": scale, + } + + # Create layer and load + layer = ops.MixedPrecisionOps.Linear(128, 256, bias=False, device="cpu", dtype=torch.float8_e4m3fn) + layer.weight = torch.nn.Parameter(torch.zeros(256, 128, dtype=torch.float8_e4m3fn)) + + # Manually call _load_from_state_dict + layer._load_from_state_dict( + state_dict, + prefix="model.diffusion_model.test_layer.", + local_metadata={}, + strict=True, + missing_keys=[], + unexpected_keys=[], + error_msgs=[] + ) + + # Verify quantization was initialized + self.assertTrue(layer._quantization_initialized) + self.assertEqual(layer.quant_format, "fp8_e4m3fn") + self.assertIsNotNone(layer.quant_scale) + + # Verify weight is wrapped + self.assertIsInstance(layer.weight.data, quant_ops.QuantizedTensorFP8) + + # Clean up + ops.MixedPrecisionOps._layer_quant_config = {} + + +class TestBackwardCompatibility(unittest.TestCase): + """Test backward compatibility with legacy systems""" + + def test_legacy_ops_classes_exist(self): + """Test that legacy ops classes still exist""" + self.assertTrue(hasattr(ops, 'disable_weight_init')) + self.assertTrue(hasattr(ops, 'manual_cast')) + self.assertTrue(hasattr(ops, 'fp8_ops')) + self.assertTrue(hasattr(ops, 'scaled_fp8_ops')) + + def test_pick_operations_legacy_path(self): + """Test pick_operations returns correct class for legacy cases""" + # Test standard case + result = ops.pick_operations(torch.float32, torch.float32) + self.assertEqual(result, ops.disable_weight_init) + + # Test manual cast case + result = ops.pick_operations(torch.float32, torch.float16) + self.assertEqual(result, ops.manual_cast) + + +class TestFP8LinearUnification(unittest.TestCase): + """Test that fp8_linear now uses the unified tensor subclass infrastructure""" + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA required for FP8") + def test_fp8_linear_uses_tensor_subclass(self): + """Verify fp8_linear wraps tensors in QuantizedTensorFP8""" + torch.manual_seed(42) + + # Create a mock Linear layer with FP8 weight + linear = ops.fp8_ops.Linear(4, 3, bias=True) + linear.weight = torch.nn.Parameter( + torch.randn(3, 4, dtype=torch.bfloat16).to(torch.float8_e4m3fn), + requires_grad=False + ) + linear.bias = torch.nn.Parameter( + torch.randn(3, dtype=torch.bfloat16), + requires_grad=False + ) + linear.scale_weight = torch.tensor(1.0) + linear.scale_input = None # No input scaling + + # Create input + input_tensor = torch.randn(2, 4, dtype=torch.bfloat16) + + # Call fp8_linear - should work without errors + try: + result = ops.fp8_linear(linear, input_tensor) + self.assertIsNotNone(result) + self.assertEqual(result.shape, (2, 3)) + except Exception as e: + # On CPU or unsupported hardware, _scaled_mm might not be available + # but the function should still complete without syntax errors + pass + + def test_fp8_linear_maintains_signature(self): + """Verify fp8_linear maintains its original function signature""" + import inspect + sig = inspect.signature(ops.fp8_linear) + params = list(sig.parameters.keys()) + + # Should have 'self' and 'input' parameters + self.assertIn('self', params) + self.assertIn('input', params) + self.assertEqual(len(params), 2) + + def test_fp8_linear_returns_none_for_non_fp8(self): + """Verify fp8_linear returns None for non-FP8 weights""" + # Create a Linear layer with BF16 weight (not FP8) + linear = ops.disable_weight_init.Linear(4, 3, bias=False) + linear.weight = torch.nn.Parameter( + torch.randn(3, 4, dtype=torch.bfloat16), + requires_grad=False + ) + + input_tensor = torch.randn(2, 4, dtype=torch.bfloat16) + + # Should return None for non-FP8 weights + result = ops.fp8_linear(linear, input_tensor) + self.assertIsNone(result) + + def test_fp8_ops_linear_uses_fp8_linear(self): + """Verify fp8_ops.Linear still uses fp8_linear in forward pass""" + linear = ops.fp8_ops.Linear(4, 3, bias=False) + + # Verify the class has the forward_comfy_cast_weights method + self.assertTrue(hasattr(linear, 'forward_comfy_cast_weights')) + + # The forward_comfy_cast_weights should attempt to call fp8_linear + # (we can't easily test this without mocking, but we verify structure) + import inspect + source = inspect.getsource(linear.forward_comfy_cast_weights) + self.assertIn('fp8_linear', source) + + +if __name__ == "__main__": + unittest.main() From 7ea731ea98445f12c07807102a1f2a4350952786 Mon Sep 17 00:00:00 2001 From: lspindler Date: Wed, 22 Oct 2025 11:25:39 +0200 Subject: [PATCH 22/49] Fix FP8 MM --- comfy/ops.py | 14 +--- comfy/quant_ops.py | 205 +++++++++++---------------------------------- 2 files changed, 48 insertions(+), 171 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 2e6782dbd4f7..060b35137f21 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -370,19 +370,7 @@ def fp8_linear(self, input): # Wrap weight in QuantizedTensorFP8 - this enables unified dispatch quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype) - - # Handle input quantization and wrapping - if self.scale_input is None: - # Clamp input to FP8 range and quantize - input = torch.clamp(input, min=-448, max=448, out=input) - input_fp8 = input.reshape(-1, input_shape[2]).to(dtype).contiguous() - else: - # Apply inverse scale and quantize - input_fp8 = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous() - - # Wrap input in QuantizedTensorFP8 - quantized_input = QuantizedTensorFP8(input_fp8, scale_input, orig_dtype=input_dtype) - + quantized_input = QuantizedTensorFP8.quantize(input.reshape(-1, input_shape[2]), scale_input, fp8_dtype=dtype) # Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py! # This is the key unification: all FP8 computation goes through one path o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 681eb9134935..8e3bacbaf8af 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -79,18 +79,47 @@ def __init__(self, tensor, scale, orig_dtype=torch.bfloat16): self._scale = scale self._orig_dtype = orig_dtype # Store a reference to prevent infinite recursion in dequantize - self._raw_data = tensor + self._raw_data = tensor.contiguous() def __repr__(self): return (f"QuantizedTensorFP8(shape={self.shape}, " f"scale={self._scale:.4f}, dtype={self._orig_dtype})") + @classmethod + def quantize(cls, tensor, scale, fp8_dtype=torch.float8_e4m3fn): + orig_dtype = tensor.dtype + + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32) + + tensor_fp8 = None + if _CK_AVAILABLE: + try: + tensor_fp8 = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype) + except Exception as e: + logging.debug(f"comfy_kitchen quantization failed, using PyTorch: {e}") + + if tensor_fp8 is None: + lp_amax = torch.finfo(fp8_dtype).max + tensor_scaled = tensor.float() / scale + torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) + tensor_fp8 = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format) + + return cls(tensor_fp8, scale, orig_dtype=orig_dtype) + + @classmethod + def quantize_dynamic(cls, tensor, strategy="amax", fp8_dtype=torch.float8_e4m3fn): + if strategy == "amax": + scale = torch.amax(tensor) / torch.finfo(fp8_dtype).max + scale = scale.to(tensor.device, dtype=torch.float32) + else: + raise ValueError(f"Unknown quantization strategy: {strategy}. " + f"Supported: 'amax'") + + return cls.quantize(tensor, scale, fp8_dtype=fp8_dtype) + @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - """ - Intercept ALL torch operations. - Routes to registered handlers or falls back to dequantization. - """ kwargs = kwargs or {} # Special case: skip dispatch for internal tensor operations @@ -134,16 +163,11 @@ def dequant_arg(arg): return func(*new_args, **new_kwargs) def dequantize(self) -> torch.Tensor: - """Explicit dequantization""" - # Use the raw data and convert directly - # Call aten ops directly to minimize dispatch interference plain_tensor = torch.ops.aten._to_copy.default(self._raw_data, dtype=self._orig_dtype) - # Multiply by scale return plain_tensor * self._scale def detach(self): """Detach returns a new QuantizedTensorFP8 (required for Parameter)""" - # Detach the raw data and create a new QuantizedTensorFP8 detached_data = self._raw_data.detach() return QuantizedTensorFP8(detached_data, self._scale, self._orig_dtype) @@ -165,48 +189,35 @@ def handle_linear_fp8(func, args, kwargs): input_tensor = args[0] weight = args[1] bias = args[2] if len(args) > 2 else None - + out_dtype = kwargs.get("out_dtype", input_tensor._orig_dtype) + # Case 1: Both input and weight are FP8 if isinstance(input_tensor, QuantizedTensorFP8) and isinstance(weight, QuantizedTensorFP8): - # Use _scaled_mm for FP8×FP8 matmul # Get plain tensors to avoid dispatch recursion plain_input = input_tensor._raw_data plain_weight = weight._raw_data - weight_t = plain_weight.t().contiguous() + weight_t = plain_weight.t() # Keep as column-major for cuBLASLt try: - if bias is not None: - output = torch._scaled_mm( - plain_input, - weight_t, - out_dtype=input_tensor._orig_dtype, - bias=bias, - scale_a=input_tensor._scale, - scale_b=weight._scale - ) - else: - output = torch._scaled_mm( - plain_input, - weight_t, - out_dtype=input_tensor._orig_dtype, - scale_a=input_tensor._scale, - scale_b=weight._scale - ) - + output = torch._scaled_mm( + plain_input, + weight_t, + bias=bias, + scale_a=input_tensor._scale, + scale_b=weight._scale, + out_dtype=out_dtype, + ) if isinstance(output, tuple): output = output[0] - # Check if output is FP8 (some architectures support this) if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - # Keep quantized! output_scale = input_tensor._scale * weight._scale - return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype) + return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype) # TODO is this correct? Can't cuBLAS return it else: return output except Exception as e: logging.debug(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") - # Fall through to dequantization path - + # Case 2: Only weight is quantized if isinstance(weight, QuantizedTensorFP8): weight_dq = weight.dequantize() @@ -222,125 +233,3 @@ def handle_linear_fp8(func, args, kwargs): else: return torch.nn.functional.linear(input_tensor, weight, bias) - -@register_quant_op(torch.ops.aten.silu.default) -def handle_silu_fp8(func, args, kwargs): - """ - SiLU can be computed approximately on FP8. - Keeps activations quantized for next layer. - """ - input_q = args[0] - - if not isinstance(input_q, QuantizedTensorFP8): - # Not quantized, use standard path - return torch.nn.functional.silu(input_q) - - # Compute SiLU while keeping quantized - # SiLU(x) = x * sigmoid(x) - - # Get plain tensor to avoid dispatch recursion - plain_tensor = input_q._raw_data - - # Upcast to FP16 for sigmoid stability - x_fp16 = plain_tensor.to(torch.float16) - sigmoid_fp16 = torch.sigmoid(x_fp16 * input_q._scale) - result_fp16 = x_fp16 * sigmoid_fp16 - - # Convert back to FP8 - result_fp8 = result_fp16.to(plain_tensor.dtype) - - # Return quantized (scale approximately preserved) - return QuantizedTensorFP8(result_fp8, input_q._scale, input_q._orig_dtype) - - -@register_quant_op(torch.ops.aten.layer_norm.default) -def handle_layernorm_fp8(func, args, kwargs): - """ - LayerNorm requires high precision. - Dequantizes input and returns standard tensor. - """ - input_q = args[0] - normalized_shape = args[1] - weight = args[2] if len(args) > 2 else None - bias = args[3] if len(args) > 3 else None - eps = args[4] if len(args) > 4 else 1e-5 - - # Dequantize if needed - if isinstance(input_q, QuantizedTensorFP8): - x = input_q.dequantize() - else: - x = input_q - - # Standard LayerNorm - result = torch.nn.functional.layer_norm(x, normalized_shape, weight, bias, eps) - - # Return dequantized (next layer will quantize if needed) - return result - - -@register_quant_op(torch.ops.aten.group_norm.default) -def handle_groupnorm_fp8(func, args, kwargs): - """ - GroupNorm requires high precision. - Dequantizes input and returns standard tensor. - """ - input_q = args[0] - num_groups = args[1] - weight = args[2] if len(args) > 2 else None - bias = args[3] if len(args) > 3 else None - eps = args[4] if len(args) > 4 else 1e-5 - - # Dequantize if needed - if isinstance(input_q, QuantizedTensorFP8): - x = input_q.dequantize() - else: - x = input_q - - # Standard GroupNorm - result = torch.nn.functional.group_norm(x, num_groups, weight, bias, eps) - - # Return dequantized - return result - - -@register_quant_op(torch.ops.aten.add.Tensor) -def handle_add_fp8(func, args, kwargs): - """ - Handle addition with mixed quantized/non-quantized tensors. - """ - a = args[0] - b = args[1] - - # If both are quantized, dequantize both - if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8): - return a.dequantize() + b.dequantize() - # If only one is quantized, dequantize it - elif isinstance(a, QuantizedTensorFP8): - return a.dequantize() + b - elif isinstance(b, QuantizedTensorFP8): - return a + b.dequantize() - # Neither is quantized - else: - return a + b - - -@register_quant_op(torch.ops.aten.mul.Tensor) -def handle_mul_fp8(func, args, kwargs): - """ - Handle multiplication with mixed quantized/non-quantized tensors. - """ - a = args[0] - b = args[1] - - # If both are quantized, dequantize both - if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8): - return a.dequantize() * b.dequantize() - # If only one is quantized, dequantize it - elif isinstance(a, QuantizedTensorFP8): - return a.dequantize() * b - elif isinstance(b, QuantizedTensorFP8): - return a * b.dequantize() - # Neither is quantized - else: - return a * b - From 4739d7717fea56750d0ef98c64268d9c1e487d78 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 23 Oct 2025 05:49:05 +1000 Subject: [PATCH 23/49] execution: fold in dependency aware caching / Fix --cache-none with loops/lazy etc (Resubmit) (#10440) * execution: fold in dependency aware caching This makes --cache-none compatiable with lazy and expanded subgraphs. Currently the --cache-none option is powered by the DependencyAwareCache. The cache attempts to maintain a parallel copy of the execution list data structure, however it is only setup once at the start of execution and does not get meaninigful updates to the execution list. This causes multiple problems when --cache-none is used with lazy and expanded subgraphs as the DAC does not accurately update its copy of the execution data structure. DAC has an attempt to handle subgraphs ensure_subcache however this does not accurately connect to nodes outside the subgraph. The current semantics of DAC are to free a node ASAP after the dependent nodes are executed. This means that if a subgraph refs such a node it will be requed and re-executed by the execution_list but DAC wont see it in its to-free lists anymore and leak memory. Rather than try and cover all the cases where the execution list changes from inside the cache, move the while problem to the executor which maintains an always up-to-date copy of the wanted data-structure. The executor now has a fast-moving run-local cache of its own. Each _to node has its own mini cache, and the cache is unconditionally primed at the time of add_strong_link. add_strong_link is called for all of static workflows, lazy links and expanded subgraphs so its the singular source of truth for output dependendencies. In the case of a cache-hit, the executor cache will hold the non-none value (it will respect updates if they happen somehow as well). In the case of a cache-miss, the executor caches a None and will wait for a notification to update the value when the node completes. When a node completes execution, it simply releases its mini-cache and in turn its strong refs on its direct anscestor outputs, allowing for ASAP freeing (same as the DependencyAwareCache but a little more automatic). This now allows for re-implementation of --cache-none with no cache at all. The dependency aware cache was also observing the dependency sematics for the objects and UI cache which is not accurate (this entire logic was always outputs specific). This also prepares for more complex caching strategies (such as RAM pressure based caching), where a cache can implement any freeing strategy completely independently of the DepedancyAwareness requirement. * main: re-implement --cache-none as no cache at all The execution list now tracks the dependency aware caching more correctly that the DependancyAwareCache. Change it to a cache that does nothing. * test_execution: add --cache-none to the test suite --cache-none is now expected to work universally. Run it through the full unit test suite. Propagate the server parameterization for whether or not the server is capabale of caching, so that the minority of tests that specifically check for cache hits can if else. Hard assert NOT caching in the else to give some coverage of --cache-none expected behaviour to not acutally cache. --- comfy_execution/caching.py | 174 ++++-------------------------- comfy_execution/graph.py | 32 +++++- execution.py | 34 +++--- main.py | 2 +- tests/execution/test_execution.py | 50 +++++---- 5 files changed, 102 insertions(+), 190 deletions(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 41224ce3b82e..566bc3f9c74a 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -265,6 +265,26 @@ async def ensure_subcache_for(self, node_id, children_ids): assert cache is not None return await cache._ensure_subcache(node_id, children_ids) +class NullCache: + + async def set_prompt(self, dynprompt, node_ids, is_changed_cache): + pass + + def all_node_ids(self): + return [] + + def clean_unused(self): + pass + + def get(self, node_id): + return None + + def set(self, node_id, value): + pass + + async def ensure_subcache_for(self, node_id, children_ids): + return self + class LRUCache(BasicCache): def __init__(self, key_class, max_size=100): super().__init__(key_class) @@ -316,157 +336,3 @@ async def ensure_subcache_for(self, node_id, children_ids): self._mark_used(child_id) self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) return self - - -class DependencyAwareCache(BasicCache): - """ - A cache implementation that tracks dependencies between nodes and manages - their execution and caching accordingly. It extends the BasicCache class. - Nodes are removed from this cache once all of their descendants have been - executed. - """ - - def __init__(self, key_class): - """ - Initialize the DependencyAwareCache. - - Args: - key_class: The class used for generating cache keys. - """ - super().__init__(key_class) - self.descendants = {} # Maps node_id -> set of descendant node_ids - self.ancestors = {} # Maps node_id -> set of ancestor node_ids - self.executed_nodes = set() # Tracks nodes that have been executed - - async def set_prompt(self, dynprompt, node_ids, is_changed_cache): - """ - Clear the entire cache and rebuild the dependency graph. - - Args: - dynprompt: The dynamic prompt object containing node information. - node_ids: List of node IDs to initialize the cache for. - is_changed_cache: Flag indicating if the cache has changed. - """ - # Clear all existing cache data - self.cache.clear() - self.subcaches.clear() - self.descendants.clear() - self.ancestors.clear() - self.executed_nodes.clear() - - # Call the parent method to initialize the cache with the new prompt - await super().set_prompt(dynprompt, node_ids, is_changed_cache) - - # Rebuild the dependency graph - self._build_dependency_graph(dynprompt, node_ids) - - def _build_dependency_graph(self, dynprompt, node_ids): - """ - Build the dependency graph for all nodes. - - Args: - dynprompt: The dynamic prompt object containing node information. - node_ids: List of node IDs to build the graph for. - """ - self.descendants.clear() - self.ancestors.clear() - for node_id in node_ids: - self.descendants[node_id] = set() - self.ancestors[node_id] = set() - - for node_id in node_ids: - inputs = dynprompt.get_node(node_id)["inputs"] - for input_data in inputs.values(): - if is_link(input_data): # Check if the input is a link to another node - ancestor_id = input_data[0] - self.descendants[ancestor_id].add(node_id) - self.ancestors[node_id].add(ancestor_id) - - def set(self, node_id, value): - """ - Mark a node as executed and store its value in the cache. - - Args: - node_id: The ID of the node to store. - value: The value to store for the node. - """ - self._set_immediate(node_id, value) - self.executed_nodes.add(node_id) - self._cleanup_ancestors(node_id) - - def get(self, node_id): - """ - Retrieve the cached value for a node. - - Args: - node_id: The ID of the node to retrieve. - - Returns: - The cached value for the node. - """ - return self._get_immediate(node_id) - - async def ensure_subcache_for(self, node_id, children_ids): - """ - Ensure a subcache exists for a node and update dependencies. - - Args: - node_id: The ID of the parent node. - children_ids: List of child node IDs to associate with the parent node. - - Returns: - The subcache object for the node. - """ - subcache = await super()._ensure_subcache(node_id, children_ids) - for child_id in children_ids: - self.descendants[node_id].add(child_id) - self.ancestors[child_id].add(node_id) - return subcache - - def _cleanup_ancestors(self, node_id): - """ - Check if ancestors of a node can be removed from the cache. - - Args: - node_id: The ID of the node whose ancestors are to be checked. - """ - for ancestor_id in self.ancestors.get(node_id, []): - if ancestor_id in self.executed_nodes: - # Remove ancestor if all its descendants have been executed - if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]): - self._remove_node(ancestor_id) - - def _remove_node(self, node_id): - """ - Remove a node from the cache. - - Args: - node_id: The ID of the node to remove. - """ - cache_key = self.cache_key_set.get_data_key(node_id) - if cache_key in self.cache: - del self.cache[cache_key] - subcache_key = self.cache_key_set.get_subcache_key(node_id) - if subcache_key in self.subcaches: - del self.subcaches[subcache_key] - - def clean_unused(self): - """ - Clean up unused nodes. This is a no-op for this cache implementation. - """ - pass - - def recursive_debug_dump(self): - """ - Dump the cache and dependency graph for debugging. - - Returns: - A list containing the cache state and dependency graph. - """ - result = super().recursive_debug_dump() - result.append({ - "descendants": self.descendants, - "ancestors": self.ancestors, - "executed_nodes": list(self.executed_nodes), - }) - return result diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index f4b427265da7..341c9735d571 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -153,8 +153,9 @@ def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None): continue _, _, input_info = self.get_input_info(unique_id, input_name) is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] - if (include_lazy or not is_lazy) and not self.is_cached(from_node_id): - node_ids.append(from_node_id) + if (include_lazy or not is_lazy): + if not self.is_cached(from_node_id): + node_ids.append(from_node_id) links.append((from_node_id, from_socket, unique_id)) for link in links: @@ -194,10 +195,35 @@ def __init__(self, dynprompt, output_cache): super().__init__(dynprompt) self.output_cache = output_cache self.staged_node_id = None + self.execution_cache = {} + self.execution_cache_listeners = {} def is_cached(self, node_id): return self.output_cache.get(node_id) is not None + def cache_link(self, from_node_id, to_node_id): + if not to_node_id in self.execution_cache: + self.execution_cache[to_node_id] = {} + self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) + if not from_node_id in self.execution_cache_listeners: + self.execution_cache_listeners[from_node_id] = set() + self.execution_cache_listeners[from_node_id].add(to_node_id) + + def get_output_cache(self, from_node_id, to_node_id): + if not to_node_id in self.execution_cache: + return None + return self.execution_cache[to_node_id].get(from_node_id) + + def cache_update(self, node_id, value): + if node_id in self.execution_cache_listeners: + for to_node_id in self.execution_cache_listeners[node_id]: + if to_node_id in self.execution_cache: + self.execution_cache[to_node_id][node_id] = value + + def add_strong_link(self, from_node_id, from_socket, to_node_id): + super().add_strong_link(from_node_id, from_socket, to_node_id) + self.cache_link(from_node_id, to_node_id) + async def stage_node_execution(self): assert self.staged_node_id is None if self.is_empty(): @@ -277,6 +303,8 @@ def unstage_node_execution(self): def complete_node_execution(self): node_id = self.staged_node_id self.pop_node(node_id) + self.execution_cache.pop(node_id, None) + self.execution_cache_listeners.pop(node_id, None) self.staged_node_id = None def get_nodes_in_cycle(self): diff --git a/execution.py b/execution.py index 1dc35738b823..78c36a4b0556 100644 --- a/execution.py +++ b/execution.py @@ -18,7 +18,7 @@ BasicCache, CacheKeySetID, CacheKeySetInputSignature, - DependencyAwareCache, + NullCache, HierarchicalCache, LRUCache, ) @@ -91,13 +91,13 @@ async def get(self, node_id): class CacheType(Enum): CLASSIC = 0 LRU = 1 - DEPENDENCY_AWARE = 2 + NONE = 2 class CacheSet: def __init__(self, cache_type=None, cache_size=None): - if cache_type == CacheType.DEPENDENCY_AWARE: - self.init_dependency_aware_cache() + if cache_type == CacheType.NONE: + self.init_null_cache() logging.info("Disabling intermediate node cache.") elif cache_type == CacheType.LRU: if cache_size is None: @@ -120,11 +120,12 @@ def init_lru_cache(self, cache_size): self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.objects = HierarchicalCache(CacheKeySetID) - # only hold cached items while the decendents have not executed - def init_dependency_aware_cache(self): - self.outputs = DependencyAwareCache(CacheKeySetInputSignature) - self.ui = DependencyAwareCache(CacheKeySetInputSignature) - self.objects = DependencyAwareCache(CacheKeySetID) + def init_null_cache(self): + self.outputs = NullCache() + #The UI cache is expected to be iterable at the end of each workflow + #so it must cache at least a full workflow. Use Heirachical + self.ui = HierarchicalCache(CacheKeySetInputSignature) + self.objects = NullCache() def recursive_debug_dump(self): result = { @@ -135,7 +136,7 @@ def recursive_debug_dump(self): SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") -def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}): +def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): is_v3 = issubclass(class_def, _ComfyNodeInternal) if is_v3: valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True) @@ -153,10 +154,10 @@ def mark_missing(): if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)): input_unique_id = input_data[0] output_index = input_data[1] - if outputs is None: + if execution_list is None: mark_missing() continue # This might be a lazily-evaluated input - cached_output = outputs.get(input_unique_id) + cached_output = execution_list.get_output_cache(input_unique_id, unique_id) if cached_output is None: mark_missing() continue @@ -405,6 +406,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, cached_output = caches.ui.get(unique_id) or {} server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) get_progress_state().finish_progress(unique_id) + execution_list.cache_update(unique_id, caches.outputs.get(unique_id)) return (ExecutionResult.SUCCESS, None, None) input_data_all = None @@ -434,7 +436,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, for r in result: if is_link(r): source_node, source_output = r[0], r[1] - node_output = caches.outputs.get(source_node)[source_output] + node_output = execution_list.get_output_cache(source_node, unique_id)[source_output] for o in node_output: resolved_output.append(o) @@ -446,7 +448,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) - input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) + input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) @@ -549,11 +551,15 @@ async def await_completion(): subcache.clean_unused() for node_id in new_output_ids: execution_list.add_node(node_id) + execution_list.cache_link(node_id, unique_id) for link in new_output_links: execution_list.add_strong_link(link[0], link[1], unique_id) pending_subgraph_results[unique_id] = cached_outputs return (ExecutionResult.PENDING, None, None) + caches.outputs.set(unique_id, output_data) + execution_list.cache_update(unique_id, output_data) + except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") diff --git a/main.py b/main.py index 35857dba8a6b..4b4c5dcc4294 100644 --- a/main.py +++ b/main.py @@ -173,7 +173,7 @@ def prompt_worker(q, server_instance): if args.cache_lru > 0: cache_type = execution.CacheType.LRU elif args.cache_none: - cache_type = execution.CacheType.DEPENDENCY_AWARE + cache_type = execution.CacheType.NONE e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru) last_gc_collect = 0 diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index ef73ad9fdd83..ace0d2279093 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -152,12 +152,12 @@ class TestExecution: # Initialize server and client # @fixture(scope="class", autouse=True, params=[ - # (use_lru, lru_size) - (False, 0), - (True, 0), - (True, 100), + { "extra_args" : [], "should_cache_results" : True }, + { "extra_args" : ["--cache-lru", 0], "should_cache_results" : True }, + { "extra_args" : ["--cache-lru", 100], "should_cache_results" : True }, + { "extra_args" : ["--cache-none"], "should_cache_results" : False }, ]) - def _server(self, args_pytest, request): + def server(self, args_pytest, request): # Start server pargs = [ 'python','main.py', @@ -167,12 +167,10 @@ def _server(self, args_pytest, request): '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml', '--cpu', ] - use_lru, lru_size = request.param - if use_lru: - pargs += ['--cache-lru', str(lru_size)] + pargs += [ str(param) for param in request.param["extra_args"] ] print("Running server with args:", pargs) # noqa: T201 p = subprocess.Popen(pargs) - yield + yield request.param p.kill() torch.cuda.empty_cache() @@ -193,7 +191,7 @@ def start_client(self, listen:str, port:int): return comfy_client @fixture(scope="class", autouse=True) - def shared_client(self, args_pytest, _server): + def shared_client(self, args_pytest, server): client = self.start_client(args_pytest["listen"], args_pytest["port"]) yield client del client @@ -225,7 +223,7 @@ def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder): assert result.did_run(mask) assert result.did_run(lazy_mix) - def test_full_cache(self, client: ComfyClient, builder: GraphBuilder): + def test_full_cache(self, client: ComfyClient, builder: GraphBuilder, server): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) @@ -237,9 +235,12 @@ def test_full_cache(self, client: ComfyClient, builder: GraphBuilder): client.run(g) result2 = client.run(g) for node_id, node in g.nodes.items(): - assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" + if server["should_cache_results"]: + assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" + else: + assert result2.did_run(node), f"Node {node_id} was cached, but should have been run" - def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder): + def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder, server): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) @@ -251,8 +252,12 @@ def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder): client.run(g) mask.inputs['value'] = 0.4 result2 = client.run(g) - assert not result2.did_run(input1), "Input1 should have been cached" - assert not result2.did_run(input2), "Input2 should have been cached" + if server["should_cache_results"]: + assert not result2.did_run(input1), "Input1 should have been cached" + assert not result2.did_run(input2), "Input2 should have been cached" + else: + assert result2.did_run(input1), "Input1 should have been rerun" + assert result2.did_run(input2), "Input2 should have been rerun" def test_error(self, client: ComfyClient, builder: GraphBuilder): g = builder @@ -411,7 +416,7 @@ def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder): input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1) client.run(g) - def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): + def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder, server): g = builder # Creating the nodes in this specific order previously caused a bug save = g.node("SaveImage") @@ -427,7 +432,10 @@ def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): result3 = client.run(g) result4 = client.run(g) assert result1.did_run(is_changed), "is_changed should have been run" - assert not result2.did_run(is_changed), "is_changed should have been cached" + if server["should_cache_results"]: + assert not result2.did_run(is_changed), "is_changed should have been cached" + else: + assert result2.did_run(is_changed), "is_changed should have been re-run" assert result3.did_run(is_changed), "is_changed should have been re-run" assert result4.did_run(is_changed), "is_changed should not have been cached" @@ -514,7 +522,7 @@ def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder): assert len(images2) == 1, "Should have 1 image" # This tests that only constant outputs are used in the call to `IS_CHANGED` - def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder): + def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server): g = builder input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1) test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5) @@ -530,7 +538,11 @@ def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilde images = result.get_images(output) assert len(images) == 1, "Should have 1 image" assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" - assert not result.did_run(test_node), "The execution should have been cached" + if server["should_cache_results"]: + assert not result.did_run(test_node), "The execution should have been cached" + else: + assert result.did_run(test_node), "The execution should have been re-run" + def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): # Warmup execution to ensure server is fully initialized From a1864c01f29cc43fe6bf823fc3fd46ba2781c2e0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 22 Oct 2025 14:26:22 -0700 Subject: [PATCH 24/49] Small readme improvement. (#10442) --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index c9a0644e33b5..434d4ff06543 100644 --- a/README.md +++ b/README.md @@ -201,6 +201,8 @@ Python 3.14 will work if you comment out the `kornia` dependency in the requirem Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 +### Instructions: + Git clone this repo. Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints From 1bcda6df987a6c92b39d8b6d29e0b029450d67d0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 23 Oct 2025 18:21:14 -0700 Subject: [PATCH 25/49] WIP way to support multi multi dimensional latents. (#10456) --- comfy/model_base.py | 10 ++++- comfy/nested_tensor.py | 91 ++++++++++++++++++++++++++++++++++++++++++ comfy/sample.py | 27 ++++++++++--- comfy/samplers.py | 23 +++++++---- comfy/utils.py | 22 ++++++++++ 5 files changed, 158 insertions(+), 15 deletions(-) create mode 100644 comfy/nested_tensor.py diff --git a/comfy/model_base.py b/comfy/model_base.py index 8274c7dea192..e877f19ac6c2 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -197,8 +197,14 @@ def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, tran extra_conds[o] = extra t = self.process_timestep(t, x=x, **extra_conds) - model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() - return self.model_sampling.calculate_denoised(sigma, model_output, x) + if "latent_shapes" in extra_conds: + xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes")) + + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds) + if len(model_output) > 1 and not torch.is_tensor(model_output): + model_output, _ = utils.pack_latents(model_output) + + return self.model_sampling.calculate_denoised(sigma, model_output.float(), x) def process_timestep(self, timestep, **kwargs): return timestep diff --git a/comfy/nested_tensor.py b/comfy/nested_tensor.py new file mode 100644 index 000000000000..b700816fabcb --- /dev/null +++ b/comfy/nested_tensor.py @@ -0,0 +1,91 @@ +import torch + +class NestedTensor: + def __init__(self, tensors): + self.tensors = list(tensors) + self.is_nested = True + + def _copy(self): + return NestedTensor(self.tensors) + + def apply_operation(self, other, operation): + o = self._copy() + if isinstance(other, NestedTensor): + for i, t in enumerate(o.tensors): + o.tensors[i] = operation(t, other.tensors[i]) + else: + for i, t in enumerate(o.tensors): + o.tensors[i] = operation(t, other) + return o + + def __add__(self, b): + return self.apply_operation(b, lambda x, y: x + y) + + def __sub__(self, b): + return self.apply_operation(b, lambda x, y: x - y) + + def __mul__(self, b): + return self.apply_operation(b, lambda x, y: x * y) + + # def __itruediv__(self, b): + # return self.apply_operation(b, lambda x, y: x / y) + + def __truediv__(self, b): + return self.apply_operation(b, lambda x, y: x / y) + + def __getitem__(self, *args, **kwargs): + return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs)) + + def unbind(self): + return self.tensors + + def to(self, *args, **kwargs): + o = self._copy() + for i, t in enumerate(o.tensors): + o.tensors[i] = t.to(*args, **kwargs) + return o + + def new_ones(self, *args, **kwargs): + return self.tensors[0].new_ones(*args, **kwargs) + + def float(self): + return self.to(dtype=torch.float) + + def chunk(self, *args, **kwargs): + return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs)) + + def size(self): + return self.tensors[0].size() + + @property + def shape(self): + return self.tensors[0].shape + + @property + def ndim(self): + dims = 0 + for t in self.tensors: + dims = max(t.ndim, dims) + return dims + + @property + def device(self): + return self.tensors[0].device + + @property + def dtype(self): + return self.tensors[0].dtype + + @property + def layout(self): + return self.tensors[0].layout + + +def cat_nested(tensors, *args, **kwargs): + cated_tensors = [] + for i in range(len(tensors[0].tensors)): + tens = [] + for j in range(len(tensors)): + tens.append(tensors[j].tensors[i]) + cated_tensors.append(torch.cat(tens, *args, **kwargs)) + return NestedTensor(cated_tensors) diff --git a/comfy/sample.py b/comfy/sample.py index be5a7e246fdf..b1395da84ae0 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -4,13 +4,9 @@ import comfy.utils import numpy as np import logging +import comfy.nested_tensor -def prepare_noise(latent_image, seed, noise_inds=None): - """ - creates random noise given a latent image and a seed. - optional arg skip can be used to skip and discard x number of noise generations for a given seed - """ - generator = torch.manual_seed(seed) +def prepare_noise_inner(latent_image, generator, noise_inds=None): if noise_inds is None: return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") @@ -22,9 +18,28 @@ def prepare_noise(latent_image, seed, noise_inds=None): noises.append(noise) noises = [noises[i] for i in inverse] noises = torch.cat(noises, axis=0) + +def prepare_noise(latent_image, seed, noise_inds=None): + """ + creates random noise given a latent image and a seed. + optional arg skip can be used to skip and discard x number of noise generations for a given seed + """ + generator = torch.manual_seed(seed) + + if latent_image.is_nested: + tensors = latent_image.unbind() + noises = [] + for t in tensors: + noises.append(prepare_noise_inner(t, generator, noise_inds)) + noises = comfy.nested_tensor.NestedTensor(noises) + else: + noises = prepare_noise_inner(latent_image, generator, noise_inds) + return noises def fix_empty_latent_channels(model, latent_image): + if latent_image.is_nested: + return latent_image latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0: latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) diff --git a/comfy/samplers.py b/comfy/samplers.py index e7efaf4705d3..fa4640842bbe 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -782,7 +782,7 @@ def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, * return KSAMPLER(sampler_function, extra_options, inpaint_options) -def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None): +def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None): for k in conds: conds[k] = conds[k][:] resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device) @@ -792,7 +792,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N if hasattr(model, 'extra_conds'): for k in conds: - conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) + conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes) #make sure each cond area has an opposite one with the same area for k in conds: @@ -962,11 +962,11 @@ def outer_predict_noise(self, x, timestep, model_options={}, seed=None): def predict_noise(self, x, timestep, model_options={}, seed=None): return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed) - def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed): + def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None): if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. latent_image = self.inner_model.process_latent_in(latent_image) - self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) + self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes) extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options) extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas @@ -980,7 +980,7 @@ def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mas samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) return self.inner_model.process_latent_out(samples.to(torch.float32)) - def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None): self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device @@ -994,7 +994,7 @@ def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, try: self.model_patcher.pre_run() - output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) finally: self.model_patcher.cleanup() @@ -1007,6 +1007,12 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba if sigmas.shape[-1] == 0: return latent_image + if latent_image.is_nested: + latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind()) + noise, _ = comfy.utils.pack_latents(noise.unbind()) + else: + latent_shapes = [latent_image.shape] + self.conds = {} for k in self.original_conds: self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) @@ -1026,7 +1032,7 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True) ) - output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) finally: cast_to_load_options(self.model_options, device=self.model_patcher.offload_device) self.model_options = orig_model_options @@ -1034,6 +1040,9 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba self.model_patcher.restore_hook_patches() del self.conds + + if len(latent_shapes) > 1: + output = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(output, latent_shapes)) return output diff --git a/comfy/utils.py b/comfy/utils.py index 0fd03f165b7c..4bd281057995 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1106,3 +1106,25 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out): dim=1 ) return out + +def pack_latents(latents): + latent_shapes = [] + tensors = [] + for tensor in latents: + latent_shapes.append(tensor.shape) + tensors.append(tensor.reshape(tensor.shape[0], 1, -1)) + + latent = torch.cat(tensors, dim=-1) + return latent, latent_shapes + +def unpack_latents(combined_latent, latent_shapes): + if len(latent_shapes) > 1: + output_tensors = [] + for shape in latent_shapes: + cut = math.prod(shape[1:]) + tens = combined_latent[:, :, :cut] + combined_latent = combined_latent[:, :, cut:] + output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:])) + else: + output_tensors = combined_latent + return output_tensors From 24188b3141aace272cb91b85578c76f5a8f70e1c Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 24 Oct 2025 13:36:30 +0800 Subject: [PATCH 26/49] Update template to 0.2.2 (#10461) Fix template typo issue --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index dd2afcab0564..8570c66b6fae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.28.7 -comfyui-workflow-templates==0.2.1 +comfyui-workflow-templates==0.2.2 comfyui-embedded-docs==0.3.0 torch torchsde From 388b306a2b48070737b092b51e76de933baee9ad Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 24 Oct 2025 08:37:16 +0300 Subject: [PATCH 27/49] feat(api-nodes): network client v2: async ops, cancellation, downloads, refactor (#10390) * feat(api-nodes): implement new API client for V3 nodes * feat(api-nodes): implement new API client for V3 nodes * feat(api-nodes): implement new API client for V3 nodes * converted WAN nodes to use new client; polishing * fix(auth): do not leak authentification for the absolute urls * convert BFL API nodes to use new API client; remove deprecated BFL nodes * converted Google Veo nodes * fix(Veo3.1 model): take into account "generate_audio" parameter --- comfy_api_nodes/apinode_utils.py | 435 +--------- comfy_api_nodes/apis/bfl_api.py | 51 +- comfy_api_nodes/apis/veo_api.py | 111 +++ comfy_api_nodes/nodes_bfl.py | 605 ++++---------- comfy_api_nodes/nodes_bytedance.py | 277 ++----- comfy_api_nodes/nodes_gemini.py | 5 +- comfy_api_nodes/nodes_kling.py | 350 +++----- comfy_api_nodes/nodes_luma.py | 2 +- comfy_api_nodes/nodes_minimax.py | 2 +- comfy_api_nodes/nodes_moonvalley.py | 366 ++------- comfy_api_nodes/nodes_openai.py | 4 +- comfy_api_nodes/nodes_pika.py | 6 +- comfy_api_nodes/nodes_pixverse.py | 13 +- comfy_api_nodes/nodes_recraft.py | 4 +- comfy_api_nodes/nodes_runway.py | 199 ++--- comfy_api_nodes/nodes_sora.py | 74 +- comfy_api_nodes/nodes_stability.py | 8 +- comfy_api_nodes/nodes_veo2.py | 176 ++-- comfy_api_nodes/nodes_vidu.py | 129 +-- comfy_api_nodes/nodes_wan.py | 245 +++--- comfy_api_nodes/util/__init__.py | 87 ++ comfy_api_nodes/util/_helpers.py | 71 ++ comfy_api_nodes/util/client.py | 941 ++++++++++++++++++++++ comfy_api_nodes/util/common_exceptions.py | 14 + comfy_api_nodes/util/conversions.py | 407 ++++++++++ comfy_api_nodes/util/download_helpers.py | 249 ++++++ comfy_api_nodes/util/upload_helpers.py | 338 ++++++++ comfy_api_nodes/util/validation_utils.py | 58 +- pyproject.toml | 2 + 29 files changed, 2933 insertions(+), 2296 deletions(-) create mode 100644 comfy_api_nodes/apis/veo_api.py create mode 100644 comfy_api_nodes/util/_helpers.py create mode 100644 comfy_api_nodes/util/client.py create mode 100644 comfy_api_nodes/util/common_exceptions.py create mode 100644 comfy_api_nodes/util/conversions.py create mode 100644 comfy_api_nodes/util/download_helpers.py create mode 100644 comfy_api_nodes/util/upload_helpers.py diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py index bc3d2d07e6a5..e3d2820592cd 100644 --- a/comfy_api_nodes/apinode_utils.py +++ b/comfy_api_nodes/apinode_utils.py @@ -1,15 +1,10 @@ from __future__ import annotations import aiohttp -import io -import logging import mimetypes -import os from typing import Optional, Union from comfy.utils import common_upscale -from comfy_api.input_impl import VideoFromFile from comfy_api.util import VideoContainer, VideoCodec from comfy_api.input.video_types import VideoInput -from comfy_api.input.basic_types import AudioInput from comfy_api_nodes.apis.client import ( ApiClient, ApiEndpoint, @@ -26,43 +21,8 @@ import torch import math import base64 -import uuid +from .util import tensor_to_bytesio, bytesio_to_image_tensor from io import BytesIO -import av - - -async def download_url_to_video_output( - video_url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None -) -> VideoFromFile: - """Downloads a video from a URL and returns a `VIDEO` output. - - Args: - video_url: The URL of the video to download. - - Returns: - A Comfy node `VIDEO` output. - """ - video_io = await download_url_to_bytesio(video_url, timeout, auth_kwargs=auth_kwargs) - if video_io is None: - error_msg = f"Failed to download video from {video_url}" - logging.error(error_msg) - raise ValueError(error_msg) - return VideoFromFile(video_io) - - -def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: - """Downscale input image tensor to roughly the specified total pixels.""" - samples = image.movedim(-1, 1) - total = int(total_pixels) - scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) - if scale_by >= 1: - return image - width = round(samples.shape[3] * scale_by) - height = round(samples.shape[2] * scale_by) - - s = common_upscale(samples, width, height, "lanczos", "disabled") - s = s.movedim(1, -1) - return s async def validate_and_cast_response( @@ -162,11 +122,6 @@ def validate_aspect_ratio( return aspect_ratio -def mimetype_to_extension(mime_type: str) -> str: - """Converts a MIME type to a file extension.""" - return mime_type.split("/")[-1].lower() - - async def download_url_to_bytesio( url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None ) -> BytesIO: @@ -195,136 +150,11 @@ async def download_url_to_bytesio( return BytesIO(await resp.read()) -def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor: - """Converts image data from BytesIO to a torch.Tensor. - - Args: - image_bytesio: BytesIO object containing the image data. - mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA"). - - Returns: - A torch.Tensor representing the image (1, H, W, C). - - Raises: - PIL.UnidentifiedImageError: If the image data cannot be identified. - ValueError: If the specified mode is invalid. - """ - image = Image.open(image_bytesio) - image = image.convert(mode) - image_array = np.array(image).astype(np.float32) / 255.0 - return torch.from_numpy(image_array).unsqueeze(0) - - -async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor: - """Downloads an image from a URL and returns a [B, H, W, C] tensor.""" - image_bytesio = await download_url_to_bytesio(url, timeout) - return bytesio_to_image_tensor(image_bytesio) - - def process_image_response(response_content: bytes | str) -> torch.Tensor: """Uses content from a Response object and converts it to a torch.Tensor""" return bytesio_to_image_tensor(BytesIO(response_content)) -def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image: - """Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling.""" - if len(image.shape) > 3: - image = image[0] - # TODO: remove alpha if not allowed and present - input_tensor = image.cpu() - input_tensor = downscale_image_tensor( - input_tensor.unsqueeze(0), total_pixels=total_pixels - ).squeeze() - image_np = (input_tensor.numpy() * 255).astype(np.uint8) - img = Image.fromarray(image_np) - return img - - -def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO: - """Converts a PIL Image to a BytesIO object.""" - if not mime_type: - mime_type = "image/png" - - img_byte_arr = io.BytesIO() - # Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG') - pil_format = mime_type.split("/")[-1].upper() - if pil_format == "JPG": - pil_format = "JPEG" - img.save(img_byte_arr, format=pil_format) - img_byte_arr.seek(0) - return img_byte_arr - - -def tensor_to_bytesio( - image: torch.Tensor, - name: Optional[str] = None, - total_pixels: int = 2048 * 2048, - mime_type: str = "image/png", -) -> BytesIO: - """Converts a torch.Tensor image to a named BytesIO object. - - Args: - image: Input torch.Tensor image. - name: Optional filename for the BytesIO object. - total_pixels: Maximum total pixels for potential downscaling. - mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). - - Returns: - Named BytesIO object containing the image data, with pointer set to the start of buffer. - """ - if not mime_type: - mime_type = "image/png" - - pil_image = _tensor_to_pil(image, total_pixels=total_pixels) - img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type) - img_binary.name = ( - f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}" - ) - return img_binary - - -def tensor_to_base64_string( - image_tensor: torch.Tensor, - total_pixels: int = 2048 * 2048, - mime_type: str = "image/png", -) -> str: - """Convert [B, H, W, C] or [H, W, C] tensor to a base64 string. - - Args: - image_tensor: Input torch.Tensor image. - total_pixels: Maximum total pixels for potential downscaling. - mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). - - Returns: - Base64 encoded string of the image. - """ - pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels) - img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type) - img_bytes = img_byte_arr.getvalue() - # Encode bytes to base64 string - base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8") - return base64_encoded_string - - -def tensor_to_data_uri( - image_tensor: torch.Tensor, - total_pixels: int = 2048 * 2048, - mime_type: str = "image/png", -) -> str: - """Converts a tensor image to a Data URI string. - - Args: - image_tensor: Input torch.Tensor image. - total_pixels: Maximum total pixels for potential downscaling. - mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp'). - - Returns: - Data URI string (e.g., 'data:image/png;base64,...'). - """ - base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type) - return f"data:{mime_type};base64,{base64_string}" - - def text_filepath_to_base64_string(filepath: str) -> str: """Converts a text file to a base64 string.""" with open(filepath, "rb") as f: @@ -392,7 +222,7 @@ def video_to_base64_string( container_format: Optional container format to use (defaults to video.container if available) codec: Optional codec to use (defaults to video.codec if available) """ - video_bytes_io = io.BytesIO() + video_bytes_io = BytesIO() # Use provided format/codec if specified, otherwise use video's own if available format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4) @@ -403,214 +233,6 @@ def video_to_base64_string( return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") -async def upload_video_to_comfyapi( - video: VideoInput, - auth_kwargs: Optional[dict[str, str]] = None, - container: VideoContainer = VideoContainer.MP4, - codec: VideoCodec = VideoCodec.H264, - max_duration: Optional[int] = None, -) -> str: - """ - Uploads a single video to ComfyUI API and returns its download URL. - Uses the specified container and codec for saving the video before upload. - - Args: - video: VideoInput object (Comfy VIDEO type). - auth_kwargs: Optional authentication token(s). - container: The video container format to use (default: MP4). - codec: The video codec to use (default: H264). - max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised. - - Returns: - The download URL for the uploaded video file. - """ - if max_duration is not None: - try: - actual_duration = video.duration_seconds - if actual_duration is not None and actual_duration > max_duration: - raise ValueError( - f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)." - ) - except Exception as e: - logging.error("Error getting video duration: %s", str(e)) - raise ValueError(f"Could not verify video duration from source: {e}") from e - - upload_mime_type = f"video/{container.value.lower()}" - filename = f"uploaded_video.{container.value.lower()}" - - # Convert VideoInput to BytesIO using specified container/codec - video_bytes_io = io.BytesIO() - video.save_to(video_bytes_io, format=container, codec=codec) - video_bytes_io.seek(0) - - return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs) - - -def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray: - """ - Prepares audio waveform for av library by converting to a contiguous numpy array. - - Args: - waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type. - - Returns: - Contiguous numpy array of the audio waveform. If the audio was batched, - the first item is taken. - """ - if waveform.ndim != 3 or waveform.shape[0] != 1: - raise ValueError("Expected waveform tensor shape (1, channels, samples)") - - # If batch is > 1, take first item - if waveform.shape[0] > 1: - waveform = waveform[0] - - # Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array - audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy() - if audio_data_np.dtype != np.float32: - audio_data_np = audio_data_np.astype(np.float32) - - return audio_data_np - - -def audio_ndarray_to_bytesio( - audio_data_np: np.ndarray, - sample_rate: int, - container_format: str = "mp4", - codec_name: str = "aac", -) -> BytesIO: - """ - Encodes a numpy array of audio data into a BytesIO object. - """ - audio_bytes_io = io.BytesIO() - with av.open(audio_bytes_io, mode="w", format=container_format) as output_container: - audio_stream = output_container.add_stream(codec_name, rate=sample_rate) - frame = av.AudioFrame.from_ndarray( - audio_data_np, - format="fltp", - layout="stereo" if audio_data_np.shape[0] > 1 else "mono", - ) - frame.sample_rate = sample_rate - frame.pts = 0 - - for packet in audio_stream.encode(frame): - output_container.mux(packet) - - # Flush stream - for packet in audio_stream.encode(None): - output_container.mux(packet) - - audio_bytes_io.seek(0) - return audio_bytes_io - - -async def upload_audio_to_comfyapi( - audio: AudioInput, - auth_kwargs: Optional[dict[str, str]] = None, - container_format: str = "mp4", - codec_name: str = "aac", - mime_type: str = "audio/mp4", - filename: str = "uploaded_audio.mp4", -) -> str: - """ - Uploads a single audio input to ComfyUI API and returns its download URL. - Encodes the raw waveform into the specified format before uploading. - - Args: - audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate) - auth_kwargs: Optional authentication token(s). - - Returns: - The download URL for the uploaded audio file. - """ - sample_rate: int = audio["sample_rate"] - waveform: torch.Tensor = audio["waveform"] - audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) - audio_bytes_io = audio_ndarray_to_bytesio( - audio_data_np, sample_rate, container_format, codec_name - ) - - return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs) - - -def f32_pcm(wav: torch.Tensor) -> torch.Tensor: - """Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file.""" - if wav.dtype.is_floating_point: - return wav - elif wav.dtype == torch.int16: - return wav.float() / (2 ** 15) - elif wav.dtype == torch.int32: - return wav.float() / (2 ** 31) - raise ValueError(f"Unsupported wav dtype: {wav.dtype}") - - -def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict: - """ - Decode any common audio container from bytes using PyAV and return - a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}. - """ - with av.open(io.BytesIO(audio_bytes)) as af: - if not af.streams.audio: - raise ValueError("No audio stream found in response.") - stream = af.streams.audio[0] - - in_sr = int(stream.codec_context.sample_rate) - out_sr = in_sr - - frames: list[torch.Tensor] = [] - n_channels = stream.channels or 1 - - for frame in af.decode(streams=stream.index): - arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T] - buf = torch.from_numpy(arr) - if buf.ndim == 1: - buf = buf.unsqueeze(0) # [T] -> [1, T] - elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels: - buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T] - elif buf.shape[0] != n_channels: - buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T] - frames.append(buf) - - if not frames: - raise ValueError("Decoded zero audio frames.") - - wav = torch.cat(frames, dim=1) # [C, T] - wav = f32_pcm(wav) - return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr} - - -def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO: - waveform = audio["waveform"].cpu() - - output_buffer = io.BytesIO() - output_container = av.open(output_buffer, mode='w', format="mp3") - - out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"]) - out_stream.bit_rate = 320000 - - frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo') - frame.sample_rate = audio["sample_rate"] - frame.pts = 0 - output_container.mux(out_stream.encode(frame)) - output_container.mux(out_stream.encode(None)) - output_container.close() - output_buffer.seek(0) - return output_buffer - - -def audio_to_base64_string( - audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac" -) -> str: - """Converts an audio input to a base64 string.""" - sample_rate: int = audio["sample_rate"] - waveform: torch.Tensor = audio["waveform"] - audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) - audio_bytes_io = audio_ndarray_to_bytesio( - audio_data_np, sample_rate, container_format, codec_name - ) - audio_bytes = audio_bytes_io.getvalue() - return base64.b64encode(audio_bytes).decode("utf-8") - - async def upload_images_to_comfyapi( image: torch.Tensor, max_images=8, @@ -663,56 +285,3 @@ def resize_mask_to_image( if not allow_gradient: mask = (mask > 0.5).float() return mask - - -def validate_string( - string: str, - strip_whitespace=True, - field_name="prompt", - min_length=None, - max_length=None, -): - if string is None: - raise Exception(f"Field '{field_name}' cannot be empty.") - if strip_whitespace: - string = string.strip() - if min_length and len(string) < min_length: - raise Exception( - f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long." - ) - if max_length and len(string) > max_length: - raise Exception( - f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long." - ) - - -def image_tensor_pair_to_batch( - image1: torch.Tensor, image2: torch.Tensor -) -> torch.Tensor: - """ - Converts a pair of image tensors to a batch tensor. - If the images are not the same size, the smaller image is resized to - match the larger image. - """ - if image1.shape[1:] != image2.shape[1:]: - image2 = common_upscale( - image2.movedim(-1, 1), - image1.shape[2], - image1.shape[1], - "bilinear", - "center", - ).movedim(1, -1) - return torch.cat((image1, image2), dim=0) - - -def get_size(path_or_object: Union[str, io.BytesIO]) -> int: - if isinstance(path_or_object, str): - return os.path.getsize(path_or_object) - return len(path_or_object.getvalue()) - - -def validate_container_format_is_mp4(video: VideoInput) -> None: - """Validates video container format is MP4.""" - container_format = video.get_container_format() - if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: - raise ValueError(f"Only MP4 container format supported. Got: {container_format}") diff --git a/comfy_api_nodes/apis/bfl_api.py b/comfy_api_nodes/apis/bfl_api.py index 0e90aef7c681..0fc8c060767c 100644 --- a/comfy_api_nodes/apis/bfl_api.py +++ b/comfy_api_nodes/apis/bfl_api.py @@ -50,44 +50,6 @@ class BFLFluxFillImageRequest(BaseModel): mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.') -class BFLFluxCannyImageRequest(BaseModel): - prompt: str = Field(..., description='Text prompt for image generation') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) - canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection') - canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection') - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') - guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided') - preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step') - - -class BFLFluxDepthImageRequest(BaseModel): - prompt: str = Field(..., description='Text prompt for image generation') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') - guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided') - preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step') - - class BFLFluxProGenerateRequest(BaseModel): prompt: str = Field(..., description='The text prompt for image generation.') prompt_upsampling: Optional[bool] = Field( @@ -160,15 +122,8 @@ class BFLStatus(str, Enum): error = "Error" -class BFLFluxProStatusResponse(BaseModel): +class BFLFluxStatusResponse(BaseModel): id: str = Field(..., description="The unique identifier for the generation task.") status: BFLStatus = Field(..., description="The status of the task.") - result: Optional[Dict[str, Any]] = Field( - None, description="The result of the task (null if not completed)." - ) - progress: confloat(ge=0.0, le=1.0) = Field( - ..., description="The progress of the task (0.0 to 1.0)." - ) - details: Optional[Dict[str, Any]] = Field( - None, description="Additional details about the task (null if not available)." - ) + result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).") + progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0) diff --git a/comfy_api_nodes/apis/veo_api.py b/comfy_api_nodes/apis/veo_api.py new file mode 100644 index 000000000000..a55137afbec1 --- /dev/null +++ b/comfy_api_nodes/apis/veo_api.py @@ -0,0 +1,111 @@ +from typing import Optional, Union +from enum import Enum + +from pydantic import BaseModel, Field + + +class Image2(BaseModel): + bytesBase64Encoded: str + gcsUri: Optional[str] = None + mimeType: Optional[str] = None + + +class Image3(BaseModel): + bytesBase64Encoded: Optional[str] = None + gcsUri: str + mimeType: Optional[str] = None + + +class Instance1(BaseModel): + image: Optional[Union[Image2, Image3]] = Field( + None, description='Optional image to guide video generation' + ) + prompt: str = Field(..., description='Text description of the video') + + +class PersonGeneration1(str, Enum): + ALLOW = 'ALLOW' + BLOCK = 'BLOCK' + + +class Parameters1(BaseModel): + aspectRatio: Optional[str] = Field(None, examples=['16:9']) + durationSeconds: Optional[int] = None + enhancePrompt: Optional[bool] = None + generateAudio: Optional[bool] = Field( + None, + description='Generate audio for the video. Only supported by veo 3 models.', + ) + negativePrompt: Optional[str] = None + personGeneration: Optional[PersonGeneration1] = None + sampleCount: Optional[int] = None + seed: Optional[int] = None + storageUri: Optional[str] = Field( + None, description='Optional Cloud Storage URI to upload the video' + ) + + +class VeoGenVidRequest(BaseModel): + instances: Optional[list[Instance1]] = None + parameters: Optional[Parameters1] = None + + +class VeoGenVidResponse(BaseModel): + name: str = Field( + ..., + description='Operation resource name', + examples=[ + 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8' + ], + ) + + +class VeoGenVidPollRequest(BaseModel): + operationName: str = Field( + ..., + description='Full operation name (from predict response)', + examples=[ + 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID' + ], + ) + + +class Video(BaseModel): + bytesBase64Encoded: Optional[str] = Field( + None, description='Base64-encoded video content' + ) + gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video') + mimeType: Optional[str] = Field(None, description='Video MIME type') + + +class Error1(BaseModel): + code: Optional[int] = Field(None, description='Error code') + message: Optional[str] = Field(None, description='Error message') + + +class Response1(BaseModel): + field_type: Optional[str] = Field( + None, + alias='@type', + examples=[ + 'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse' + ], + ) + raiMediaFilteredCount: Optional[int] = Field( + None, description='Count of media filtered by responsible AI policies' + ) + raiMediaFilteredReasons: Optional[list[str]] = Field( + None, description='Reasons why media was filtered by responsible AI policies' + ) + videos: Optional[list[Video]] = None + + +class VeoGenVidPollResponse(BaseModel): + done: Optional[bool] = None + error: Optional[Error1] = Field( + None, description='Error details if operation failed' + ) + name: Optional[str] = None + response: Optional[Response1] = Field( + None, description='The actual prediction response if done is true' + ) diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index b6cc90f05c7a..baa74fd529d8 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -1,136 +1,43 @@ -import asyncio -import io from inspect import cleandoc -from typing import Union, Optional +from typing import Optional + +import torch from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO + +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.apinode_utils import ( + resize_mask_to_image, + validate_aspect_ratio, +) from comfy_api_nodes.apis.bfl_api import ( - BFLStatus, BFLFluxExpandImageRequest, BFLFluxFillImageRequest, - BFLFluxCannyImageRequest, - BFLFluxDepthImageRequest, - BFLFluxProGenerateRequest, BFLFluxKontextProGenerateRequest, - BFLFluxProUltraGenerateRequest, + BFLFluxProGenerateRequest, BFLFluxProGenerateResponse, + BFLFluxProUltraGenerateRequest, + BFLFluxStatusResponse, + BFLStatus, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, -) -from comfy_api_nodes.apinode_utils import ( - downscale_image_tensor, - validate_aspect_ratio, - process_image_response, - resize_mask_to_image, + download_url_to_image_tensor, + poll_op, + sync_op, + tensor_to_base64_string, validate_string, ) -import numpy as np -from PIL import Image -import aiohttp -import torch -import base64 -import time -from server import PromptServer - def convert_mask_to_image(mask: torch.Tensor): """ Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image. """ mask = mask.unsqueeze(-1) - mask = torch.cat([mask]*3, dim=-1) + mask = torch.cat([mask] * 3, dim=-1) return mask -async def handle_bfl_synchronous_operation( - operation: SynchronousOperation, - timeout_bfl_calls=360, - node_id: Union[str, None] = None, -): - response_api: BFLFluxProGenerateResponse = await operation.execute() - return await _poll_until_generated( - response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id - ) - - -async def _poll_until_generated( - polling_url: str, timeout=360, node_id: Union[str, None] = None -): - # used bfl-comfy-nodes to verify code implementation: - # https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main - start_time = time.time() - retries_404 = 0 - max_retries_404 = 5 - retry_404_seconds = 2 - retry_202_seconds = 2 - retry_pending_seconds = 1 - - async with aiohttp.ClientSession() as session: - # NOTE: should True loop be replaced with checking if workflow has been interrupted? - while True: - if node_id: - time_elapsed = time.time() - start_time - PromptServer.instance.send_progress_text( - f"Generating ({time_elapsed:.0f}s)", node_id - ) - - async with session.get(polling_url) as response: - if response.status == 200: - result = await response.json() - if result["status"] == BFLStatus.ready: - img_url = result["result"]["sample"] - if node_id: - PromptServer.instance.send_progress_text( - f"Result URL: {img_url}", node_id - ) - async with session.get(img_url) as img_resp: - return process_image_response(await img_resp.content.read()) - elif result["status"] in [ - BFLStatus.request_moderated, - BFLStatus.content_moderated, - ]: - status = result["status"] - raise Exception( - f"BFL API did not return an image due to: {status}." - ) - elif result["status"] == BFLStatus.error: - raise Exception(f"BFL API encountered an error: {result}.") - elif result["status"] == BFLStatus.pending: - await asyncio.sleep(retry_pending_seconds) - continue - elif response.status == 404: - if retries_404 < max_retries_404: - retries_404 += 1 - await asyncio.sleep(retry_404_seconds) - continue - raise Exception( - f"BFL API could not find task after {max_retries_404} tries." - ) - elif response.status == 202: - await asyncio.sleep(retry_202_seconds) - elif time.time() - start_time > timeout: - raise Exception( - f"BFL API experienced a timeout; could not return request under {timeout} seconds." - ) - else: - raise Exception(f"BFL API encountered an error: {response.json()}") - -def convert_image_to_base64(image: torch.Tensor): - scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048) - # remove batch dimension if present - if len(scaled_image.shape) > 3: - scaled_image = scaled_image[0] - image_np = (scaled_image.numpy() * 255).astype(np.uint8) - img = Image.fromarray(image_np) - img_byte_arr = io.BytesIO() - img.save(img_byte_arr, format="PNG") - return base64.b64encode(img_byte_arr.getvalue()).decode() - - class FluxProUltraImageNode(IO.ComfyNode): """ Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution. @@ -158,7 +65,9 @@ def define_schema(cls) -> IO.Schema: IO.Boolean.Input( "prompt_upsampling", default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", ), IO.Int.Input( "seed", @@ -220,22 +129,19 @@ async def execute( cls, prompt: str, aspect_ratio: str, - prompt_upsampling=False, - raw=False, - seed=0, - image_prompt=None, - image_prompt_strength=0.1, + prompt_upsampling: bool = False, + raw: bool = False, + seed: int = 0, + image_prompt: Optional[torch.Tensor] = None, + image_prompt_strength: float = 0.1, ) -> IO.NodeOutput: if image_prompt is None: validate_string(prompt, strip_whitespace=False) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.1-ultra/generate", - method=HttpMethod.POST, - request_model=BFLFluxProUltraGenerateRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxProUltraGenerateRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-pro-1.1-ultra/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxProUltraGenerateRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, seed=seed, @@ -247,22 +153,26 @@ async def execute( maximum_ratio_str=cls.MAXIMUM_RATIO_STR, ), raw=raw, - image_prompt=( - image_prompt - if image_prompt is None - else convert_image_to_base64(image_prompt) - ), - image_prompt_strength=( - None if image_prompt is None else round(image_prompt_strength, 2) - ), + image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)), + image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)), ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class FluxKontextProImageNode(IO.ComfyNode): @@ -347,7 +257,7 @@ async def execute( aspect_ratio: str, guidance: float, steps: int, - input_image: Optional[torch.Tensor]=None, + input_image: Optional[torch.Tensor] = None, seed=0, prompt_upsampling=False, ) -> IO.NodeOutput: @@ -360,33 +270,36 @@ async def execute( ) if input_image is None: validate_string(prompt, strip_whitespace=False) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=cls.BFL_PATH, - method=HttpMethod.POST, - request_model=BFLFluxKontextProGenerateRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxKontextProGenerateRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path=cls.BFL_PATH, method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxKontextProGenerateRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, guidance=round(guidance, 1), steps=steps, seed=seed, aspect_ratio=aspect_ratio, - input_image=( - input_image - if input_image is None - else convert_image_to_base64(input_image) - ) + input_image=(input_image if input_image is None else tensor_to_base64_string(input_image)), ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class FluxKontextMaxImageNode(FluxKontextProImageNode): @@ -422,7 +335,9 @@ def define_schema(cls) -> IO.Schema: IO.Boolean.Input( "prompt_upsampling", default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", ), IO.Int.Input( "width", @@ -481,20 +396,15 @@ async def execute( image_prompt=None, # image_prompt_strength=0.1, ) -> IO.NodeOutput: - image_prompt = ( - image_prompt - if image_prompt is None - else convert_image_to_base64(image_prompt) - ) - - operation = SynchronousOperation( - endpoint=ApiEndpoint( + image_prompt = image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt) + initial_response = await sync_op( + cls, + ApiEndpoint( path="/proxy/bfl/flux-pro-1.1/generate", - method=HttpMethod.POST, - request_model=BFLFluxProGenerateRequest, - response_model=BFLFluxProGenerateResponse, + method="POST", ), - request=BFLFluxProGenerateRequest( + response_model=BFLFluxProGenerateResponse, + data=BFLFluxProGenerateRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, width=width, @@ -502,13 +412,23 @@ async def execute( seed=seed, image_prompt=image_prompt, ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class FluxProExpandNode(IO.ComfyNode): @@ -534,7 +454,9 @@ def define_schema(cls) -> IO.Schema: IO.Boolean.Input( "prompt_upsampling", default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", ), IO.Int.Input( "top", @@ -610,16 +532,11 @@ async def execute( guidance: float, seed=0, ) -> IO.NodeOutput: - image = convert_image_to_base64(image) - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-expand/generate", - method=HttpMethod.POST, - request_model=BFLFluxExpandImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxExpandImageRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-expand/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxExpandImageRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, top=top, @@ -629,16 +546,25 @@ async def execute( steps=steps, guidance=guidance, seed=seed, - image=image, + image=tensor_to_base64_string(image), ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) - + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class FluxProFillNode(IO.ComfyNode): @@ -665,7 +591,9 @@ def define_schema(cls) -> IO.Schema: IO.Boolean.Input( "prompt_upsampling", default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", ), IO.Float.Input( "guidance", @@ -712,272 +640,37 @@ async def execute( ) -> IO.NodeOutput: # prepare mask mask = resize_mask_to_image(mask, image) - mask = convert_image_to_base64(convert_mask_to_image(mask)) - # make sure image will have alpha channel removed - image = convert_image_to_base64(image[:, :, :, :3]) - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-fill/generate", - method=HttpMethod.POST, - request_model=BFLFluxFillImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxFillImageRequest( + mask = tensor_to_base64_string(convert_mask_to_image(mask)) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-fill/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxFillImageRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, steps=steps, guidance=guidance, seed=seed, - image=image, + image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed mask=mask, ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) - - -class FluxProCannyNode(IO.ComfyNode): - """ - Generate image using a control image (canny). - """ - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="FluxProCannyNode", - display_name="Flux.1 Canny Control Image", - category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), - inputs=[ - IO.Image.Input("control_image"), - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Prompt for the image generation", - ), - IO.Boolean.Input( - "prompt_upsampling", - default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - ), - IO.Float.Input( - "canny_low_threshold", - default=0.1, - min=0.01, - max=0.99, - step=0.01, - tooltip="Low threshold for Canny edge detection; ignored if skip_processing is True", - ), - IO.Float.Input( - "canny_high_threshold", - default=0.4, - min=0.01, - max=0.99, - step=0.01, - tooltip="High threshold for Canny edge detection; ignored if skip_processing is True", - ), - IO.Boolean.Input( - "skip_preprocessing", - default=False, - tooltip="Whether to skip preprocessing; set to True if control_image already is canny-fied, False if it is a raw image.", - ), - IO.Float.Input( - "guidance", - default=30, - min=1, - max=100, - tooltip="Guidance strength for the image generation process", - ), - IO.Int.Input( - "steps", - default=50, - min=15, - max=50, - tooltip="Number of steps for the image generation process", - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=0xFFFFFFFFFFFFFFFF, - control_after_generate=True, - tooltip="The random seed used for creating the noise.", - ), - ], - outputs=[IO.Image.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - ) - - @classmethod - async def execute( - cls, - control_image: torch.Tensor, - prompt: str, - prompt_upsampling: bool, - canny_low_threshold: float, - canny_high_threshold: float, - skip_preprocessing: bool, - steps: int, - guidance: float, - seed=0, - ) -> IO.NodeOutput: - control_image = convert_image_to_base64(control_image[:, :, :, :3]) - preprocessed_image = None - - # scale canny threshold between 0-500, to match BFL's API - def scale_value(value: float, min_val=0, max_val=500): - return min_val + value * (max_val - min_val) - canny_low_threshold = int(round(scale_value(canny_low_threshold))) - canny_high_threshold = int(round(scale_value(canny_high_threshold))) - - - if skip_preprocessing: - preprocessed_image = control_image - control_image = None - canny_low_threshold = None - canny_high_threshold = None - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-canny/generate", - method=HttpMethod.POST, - request_model=BFLFluxCannyImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxCannyImageRequest( - prompt=prompt, - prompt_upsampling=prompt_upsampling, - steps=steps, - guidance=guidance, - seed=seed, - control_image=control_image, - canny_low_threshold=canny_low_threshold, - canny_high_threshold=canny_high_threshold, - preprocessed_image=preprocessed_image, - ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) - - -class FluxProDepthNode(IO.ComfyNode): - """ - Generate image using a control image (depth). - """ - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="FluxProDepthNode", - display_name="Flux.1 Depth Control Image", - category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), - inputs=[ - IO.Image.Input("control_image"), - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Prompt for the image generation", - ), - IO.Boolean.Input( - "prompt_upsampling", - default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - ), - IO.Boolean.Input( - "skip_preprocessing", - default=False, - tooltip="Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.", - ), - IO.Float.Input( - "guidance", - default=15, - min=1, - max=100, - tooltip="Guidance strength for the image generation process", - ), - IO.Int.Input( - "steps", - default=50, - min=15, - max=50, - tooltip="Number of steps for the image generation process", - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=0xFFFFFFFFFFFFFFFF, - control_after_generate=True, - tooltip="The random seed used for creating the noise.", - ), - ], - outputs=[IO.Image.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, ], - is_api_node=True, - ) - - @classmethod - async def execute( - cls, - control_image: torch.Tensor, - prompt: str, - prompt_upsampling: bool, - skip_preprocessing: bool, - steps: int, - guidance: float, - seed=0, - ) -> IO.NodeOutput: - control_image = convert_image_to_base64(control_image[:,:,:,:3]) - preprocessed_image = None - - if skip_preprocessing: - preprocessed_image = control_image - control_image = None - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-depth/generate", - method=HttpMethod.POST, - request_model=BFLFluxDepthImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxDepthImageRequest( - prompt=prompt, - prompt_upsampling=prompt_upsampling, - steps=steps, - guidance=guidance, - seed=seed, - control_image=control_image, - preprocessed_image=preprocessed_image, - ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, + queued_statuses=[], ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class BFLExtension(ComfyExtension): @@ -990,8 +683,6 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]: FluxKontextMaxImageNode, FluxProExpandNode, FluxProFillNode, - FluxProCannyNode, - FluxProDepthNode, ] diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index f3d3f8d3eeab..534af380debb 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -1,35 +1,27 @@ import logging import math from enum import Enum -from typing import Literal, Optional, Type, Union -from typing_extensions import override +from typing import Literal, Optional, Union import torch from pydantic import BaseModel, Field +from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO -from comfy_api_nodes.util.validation_utils import ( - validate_image_aspect_ratio_range, - get_number_of_images, - validate_image_dimensions, -) -from comfy_api_nodes.apis.client import ( +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.util import ( ApiEndpoint, - EmptyRequest, - HttpMethod, - SynchronousOperation, - PollingOperation, - T, -) -from comfy_api_nodes.apinode_utils import ( download_url_to_image_tensor, download_url_to_video_output, + get_number_of_images, + image_tensor_pair_to_batch, + poll_op, + sync_op, upload_images_to_comfyapi, + validate_image_aspect_ratio_range, + validate_image_dimensions, validate_string, - image_tensor_pair_to_batch, ) - BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations" # Long-running tasks endpoints(e.g., video) @@ -46,13 +38,14 @@ class Image2ImageModelName(str, Enum): class Text2VideoModelName(str, Enum): - seedance_1_pro = "seedance-1-0-pro-250528" + seedance_1_pro = "seedance-1-0-pro-250528" seedance_1_lite = "seedance-1-0-lite-t2v-250428" class Image2VideoModelName(str, Enum): """note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757""" - seedance_1_pro = "seedance-1-0-pro-250528" + + seedance_1_pro = "seedance-1-0-pro-250528" seedance_1_lite = "seedance-1-0-lite-i2v-250428" @@ -208,35 +201,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N return None -async def poll_until_finished( - auth_kwargs: dict[str, str], - task_id: str, - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> TaskStatusResponse: - """Polls the ByteDance API endpoint until the task reaches a terminal state, then returns the response.""" - return await PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), - completed_statuses=[ - "succeeded", - ], - failed_statuses=[ - "cancelled", - "failed", - ], - status_extractor=lambda response: response.status, - auth_kwargs=auth_kwargs, - result_url_extractor=get_video_url_from_task_status, - estimated_duration=estimated_duration, - node_id=node_id, - ).execute() - - class ByteDanceImageNode(IO.ComfyNode): @classmethod @@ -303,7 +267,7 @@ def define_schema(cls): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the image", + tooltip='Whether to add an "AI generated" watermark to the image', optional=True, ), ], @@ -341,8 +305,7 @@ async def execute( w, h = width, height if not (512 <= w <= 2048) or not (512 <= h <= 2048): raise ValueError( - f"Custom size out of range: {w}x{h}. " - "Both width and height must be between 512 and 2048 pixels." + f"Custom size out of range: {w}x{h}. " "Both width and height must be between 512 and 2048 pixels." ) payload = Text2ImageTaskCreationRequest( @@ -353,20 +316,12 @@ async def execute( guidance_scale=guidance_scale, watermark=watermark, ) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_IMAGE_ENDPOINT, - method=HttpMethod.POST, - request_model=Text2ImageTaskCreationRequest, - response_model=ImageTaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() + response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + data=payload, + response_model=ImageTaskCreationResponse, + ) return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) @@ -420,7 +375,7 @@ def define_schema(cls): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the image", + tooltip='Whether to add an "AI generated" watermark to the image', optional=True, ), ], @@ -449,16 +404,7 @@ async def execute( if get_number_of_images(image) != 1: raise ValueError("Exactly one input image is required.") validate_image_aspect_ratio_range(image, (1, 3), (3, 1)) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - source_url = (await upload_images_to_comfyapi( - image, - max_images=1, - mime_type="image/png", - auth_kwargs=auth_kwargs, - ))[0] + source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0] payload = Image2ImageTaskCreationRequest( model=model, prompt=prompt, @@ -467,16 +413,12 @@ async def execute( guidance_scale=guidance_scale, watermark=watermark, ) - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_IMAGE_ENDPOINT, - method=HttpMethod.POST, - request_model=Image2ImageTaskCreationRequest, - response_model=ImageTaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() + response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + data=payload, + response_model=ImageTaskCreationResponse, + ) return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) @@ -504,7 +446,7 @@ def define_schema(cls): IO.Image.Input( "image", tooltip="Input image(s) for image-to-image generation. " - "List of 1-10 images for single or multi-reference generation.", + "List of 1-10 images for single or multi-reference generation.", optional=True, ), IO.Combo.Input( @@ -534,9 +476,9 @@ def define_schema(cls): "sequential_image_generation", options=["disabled", "auto"], tooltip="Group image generation mode. " - "'disabled' generates a single image. " - "'auto' lets the model decide whether to generate multiple related images " - "(e.g., story scenes, character variations).", + "'disabled' generates a single image. " + "'auto' lets the model decide whether to generate multiple related images " + "(e.g., story scenes, character variations).", optional=True, ), IO.Int.Input( @@ -547,7 +489,7 @@ def define_schema(cls): step=1, display_mode=IO.NumberDisplay.number, tooltip="Maximum number of images to generate when sequential_image_generation='auto'. " - "Total images (input + generated) cannot exceed 15.", + "Total images (input + generated) cannot exceed 15.", optional=True, ), IO.Int.Input( @@ -564,7 +506,7 @@ def define_schema(cls): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the image.", + tooltip='Whether to add an "AI generated" watermark to the image.', optional=True, ), IO.Boolean.Input( @@ -611,8 +553,7 @@ async def execute( w, h = width, height if not (1024 <= w <= 4096) or not (1024 <= h <= 4096): raise ValueError( - f"Custom size out of range: {w}x{h}. " - "Both width and height must be between 1024 and 4096 pixels." + f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels." ) n_input_images = get_number_of_images(image) if image is not None else 0 if n_input_images > 10: @@ -621,41 +562,31 @@ async def execute( raise ValueError( "The maximum number of generated images plus the number of reference images cannot exceed 15." ) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } reference_images_urls = [] if n_input_images: for i in image: validate_image_aspect_ratio_range(i, (1, 3), (3, 1)) - reference_images_urls = (await upload_images_to_comfyapi( + reference_images_urls = await upload_images_to_comfyapi( + cls, image, max_images=n_input_images, mime_type="image/png", - auth_kwargs=auth_kwargs, - )) - payload = Seedream4TaskCreationRequest( - model=model, - prompt=prompt, - image=reference_images_urls, - size=f"{w}x{h}", - seed=seed, - sequential_image_generation=sequential_image_generation, - sequential_image_generation_options=Seedream4Options(max_images=max_images), - watermark=watermark, - ) - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_IMAGE_ENDPOINT, - method=HttpMethod.POST, - request_model=Seedream4TaskCreationRequest, - response_model=ImageTaskCreationResponse, + ) + response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + response_model=ImageTaskCreationResponse, + data=Seedream4TaskCreationRequest( + model=model, + prompt=prompt, + image=reference_images_urls, + size=f"{w}x{h}", + seed=seed, + sequential_image_generation=sequential_image_generation, + sequential_image_generation_options=Seedream4Options(max_images=max_images), + watermark=watermark, ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() - + ) if len(response.data) == 1: return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) urls = [str(d["url"]) for d in response.data if isinstance(d, dict) and "url" in d] @@ -719,13 +650,13 @@ def define_schema(cls): "camera_fixed", default=False, tooltip="Specifies whether to fix the camera. The platform appends an instruction " - "to fix the camera to your prompt, but does not guarantee the actual effect.", + "to fix the camera to your prompt, but does not guarantee the actual effect.", optional=True, ), IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the video.", + tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), ], @@ -764,19 +695,9 @@ async def execute( f"--camerafixed {str(camera_fixed).lower()} " f"--watermark {str(watermark).lower()}" ) - - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } return await process_video_task( - request_model=Text2VideoTaskCreationRequest, - payload=Text2VideoTaskCreationRequest( - model=model, - content=[TaskTextContent(text=prompt)], - ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, + cls, + payload=Text2VideoTaskCreationRequest(model=model, content=[TaskTextContent(text=prompt)]), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -840,13 +761,13 @@ def define_schema(cls): "camera_fixed", default=False, tooltip="Specifies whether to fix the camera. The platform appends an instruction " - "to fix the camera to your prompt, but does not guarantee the actual effect.", + "to fix the camera to your prompt, but does not guarantee the actual effect.", optional=True, ), IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the video.", + tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), ], @@ -879,13 +800,7 @@ async def execute( validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth_kwargs))[0] - + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] prompt = ( f"{prompt} " f"--resolution {resolution} " @@ -897,13 +812,11 @@ async def execute( ) return await process_video_task( - request_model=Image2VideoTaskCreationRequest, + cls, payload=Image2VideoTaskCreationRequest( model=model, content=[TaskTextContent(text=prompt), TaskImageContent(image_url=TaskImageContentUrl(url=image_url))], ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -971,13 +884,13 @@ def define_schema(cls): "camera_fixed", default=False, tooltip="Specifies whether to fix the camera. The platform appends an instruction " - "to fix the camera to your prompt, but does not guarantee the actual effect.", + "to fix the camera to your prompt, but does not guarantee the actual effect.", optional=True, ), IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the video.", + tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), ], @@ -1012,16 +925,11 @@ async def execute( validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - download_urls = await upload_images_to_comfyapi( + cls, image_tensor_pair_to_batch(first_frame, last_frame), max_images=2, mime_type="image/png", - auth_kwargs=auth_kwargs, ) prompt = ( @@ -1035,7 +943,7 @@ async def execute( ) return await process_video_task( - request_model=Image2VideoTaskCreationRequest, + cls, payload=Image2VideoTaskCreationRequest( model=model, content=[ @@ -1044,8 +952,6 @@ async def execute( TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[1])), role="last_frame"), ], ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -1108,7 +1014,7 @@ def define_schema(cls): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the video.", + tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), ], @@ -1141,15 +1047,7 @@ async def execute( validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - image_urls = await upload_images_to_comfyapi( - images, max_images=4, mime_type="image/png", auth_kwargs=auth_kwargs - ) - + image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png") prompt = ( f"{prompt} " f"--resolution {resolution} " @@ -1160,42 +1058,32 @@ async def execute( ) x = [ TaskTextContent(text=prompt), - *[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls] + *[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls], ] return await process_video_task( - request_model=Image2VideoTaskCreationRequest, - payload=Image2VideoTaskCreationRequest( - model=model, - content=x, - ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, + cls, + payload=Image2VideoTaskCreationRequest(model=model, content=x), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) async def process_video_task( - request_model: Type[T], + cls: type[IO.ComfyNode], payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest], - auth_kwargs: dict, - node_id: str, estimated_duration: Optional[int], ) -> IO.NodeOutput: - initial_response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_TASK_ENDPOINT, - method=HttpMethod.POST, - request_model=request_model, - response_model=TaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() - response = await poll_until_finished( - auth_kwargs, - initial_response.id, + initial_response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), + data=payload, + response_model=TaskCreationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"), + status_extractor=lambda r: r.status, estimated_duration=estimated_duration, - node_id=node_id, + response_model=TaskStatusResponse, ) return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response))) @@ -1221,5 +1109,6 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]: ByteDanceImageReferenceNode, ] + async def comfy_entrypoint() -> ByteDanceExtension: return ByteDanceExtension() diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index c1941cbe929f..ca11b67ed192 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -33,12 +33,9 @@ SynchronousOperation, ) from comfy_api_nodes.apinode_utils import ( - validate_string, - audio_to_base64_string, video_to_base64_string, - tensor_to_base64_string, - bytesio_to_image_tensor, ) +from comfy_api_nodes.util import validate_string, tensor_to_base64_string, bytesio_to_image_tensor, audio_to_base64_string from comfy_api.util import VideoContainer, VideoCodec diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 67c8307c55ff..eea65c9acf97 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -5,8 +5,7 @@ """ from __future__ import annotations -from typing import Optional, TypeVar, Any -from collections.abc import Callable +from typing import Optional, TypeVar import math import logging @@ -15,7 +14,6 @@ import torch from comfy_api_nodes.apis import ( - KlingTaskStatus, KlingCameraControl, KlingCameraConfig, KlingCameraControlType, @@ -52,26 +50,20 @@ KlingCharacterEffectModelName, KlingSingleImageEffectModelName, ) -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( - tensor_to_base64_string, - download_url_to_video_output, - upload_video_to_comfyapi, - upload_audio_to_comfyapi, - download_url_to_image_tensor, - validate_string, -) -from comfy_api_nodes.util.validation_utils import ( +from comfy_api_nodes.util import ( validate_image_dimensions, validate_image_aspect_ratio, validate_video_dimensions, validate_video_duration, + tensor_to_base64_string, + validate_string, + upload_audio_to_comfyapi, + download_url_to_image_tensor, + upload_video_to_comfyapi, + download_url_to_video_output, + sync_op, + ApiEndpoint, + poll_op, ) from comfy_api.input_impl import VideoFromFile from comfy_api.input.basic_types import AudioInput @@ -214,34 +206,6 @@ } -async def poll_until_finished( - auth_kwargs: dict[str, str], - api_endpoint: ApiEndpoint[Any, R], - result_url_extractor: Optional[Callable[[R], str]] = None, - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> R: - """Polls the Kling API endpoint until the task reaches a terminal state, then returns the response.""" - return await PollingOperation( - poll_endpoint=api_endpoint, - completed_statuses=[ - KlingTaskStatus.succeed.value, - ], - failed_statuses=[KlingTaskStatus.failed.value], - status_extractor=lambda response: ( - response.data.task_status.value - if response.data and response.data.task_status - else None - ), - auth_kwargs=auth_kwargs, - result_url_extractor=result_url_extractor, - estimated_duration=estimated_duration, - node_id=node_id, - poll_interval=16.0, - max_poll_attempts=256, - ).execute() - - def is_valid_camera_control_configs(configs: list[float]) -> bool: """Verifies that at least one camera control configuration is non-zero.""" return any(not math.isclose(value, 0.0) for value in configs) @@ -377,8 +341,7 @@ async def image_result_to_node_output( async def execute_text2video( - auth_kwargs: dict[str, str], - node_id: str, + cls: type[IO.ComfyNode], prompt: str, negative_prompt: str, cfg_scale: float, @@ -389,14 +352,11 @@ async def execute_text2video( camera_control: Optional[KlingCameraControl] = None, ) -> IO.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_TEXT_TO_VIDEO, - method=HttpMethod.POST, - request_model=KlingText2VideoRequest, - response_model=KlingText2VideoResponse, - ), - request=KlingText2VideoRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"), + response_model=KlingText2VideoResponse, + data=KlingText2VideoRequest( prompt=prompt if prompt else None, negative_prompt=negative_prompt if negative_prompt else None, duration=KlingVideoGenDuration(duration), @@ -406,24 +366,17 @@ async def execute_text2video( aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio), camera_control=camera_control, ), - auth_kwargs=auth_kwargs, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_TEXT_TO_VIDEO}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingText2VideoResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_TEXT_TO_VIDEO}/{task_id}"), + response_model=KlingText2VideoResponse, estimated_duration=AVERAGE_DURATION_T2V, - node_id=node_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -432,8 +385,7 @@ async def execute_text2video( async def execute_image2video( - auth_kwargs: dict[str, str], - node_id: str, + cls: type[IO.ComfyNode], start_frame: torch.Tensor, prompt: str, negative_prompt: str, @@ -455,14 +407,11 @@ async def execute_image2video( if model_mode == "std" and model_name == KlingVideoGenModelName.kling_v2_5_turbo.value: model_mode = "pro" # October 5: currently "std" mode is not supported for this model - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_TO_VIDEO, - method=HttpMethod.POST, - request_model=KlingImage2VideoRequest, - response_model=KlingImage2VideoResponse, - ), - request=KlingImage2VideoRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), + response_model=KlingImage2VideoResponse, + data=KlingImage2VideoRequest( model_name=KlingVideoGenModelName(model_name), image=tensor_to_base64_string(start_frame), image_tail=( @@ -477,24 +426,17 @@ async def execute_image2video( duration=KlingVideoGenDuration(duration), camera_control=camera_control, ), - auth_kwargs=auth_kwargs, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}", - method=HttpMethod.GET, - request_model=KlingImage2VideoRequest, - response_model=KlingImage2VideoResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"), + response_model=KlingImage2VideoResponse, estimated_duration=AVERAGE_DURATION_I2V, - node_id=node_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -503,8 +445,7 @@ async def execute_image2video( async def execute_video_effect( - auth_kwargs: dict[str, str], - node_id: str, + cls: type[IO.ComfyNode], dual_character: bool, effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene, model_name: str, @@ -530,35 +471,25 @@ async def execute_video_effect( duration=duration, ) - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_VIDEO_EFFECTS, - method=HttpMethod.POST, - request_model=KlingVideoEffectsRequest, - response_model=KlingVideoEffectsResponse, - ), - request=KlingVideoEffectsRequest( + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=PATH_VIDEO_EFFECTS, method="POST"), + response_model=KlingVideoEffectsResponse, + data=KlingVideoEffectsRequest( effect_scene=effect_scene, input=request_input_field, ), - auth_kwargs=auth_kwargs, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_VIDEO_EFFECTS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingVideoEffectsResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_VIDEO_EFFECTS}/{task_id}"), + response_model=KlingVideoEffectsResponse, estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS, - node_id=node_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -567,8 +498,7 @@ async def execute_video_effect( async def execute_lipsync( - auth_kwargs: dict[str, str], - node_id: str, + cls: type[IO.ComfyNode], video: VideoInput, audio: Optional[AudioInput] = None, voice_language: Optional[str] = None, @@ -583,24 +513,21 @@ async def execute_lipsync( validate_video_duration(video, 2, 10) # Upload video to Comfy API and get download URL - video_url = await upload_video_to_comfyapi(video, auth_kwargs=auth_kwargs) + video_url = await upload_video_to_comfyapi(cls, video) logging.info("Uploaded video to Comfy API. URL: %s", video_url) # Upload the audio file to Comfy API and get download URL if audio: - audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=auth_kwargs) + audio_url = await upload_audio_to_comfyapi(cls, audio) logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) else: audio_url = None - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_LIP_SYNC, - method=HttpMethod.POST, - request_model=KlingLipSyncRequest, - response_model=KlingLipSyncResponse, - ), - request=KlingLipSyncRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(PATH_LIP_SYNC, "POST"), + response_model=KlingLipSyncResponse, + data=KlingLipSyncRequest( input=KlingLipSyncInputObject( video_url=video_url, mode=model_mode, @@ -612,24 +539,17 @@ async def execute_lipsync( voice_id=voice_id, ), ), - auth_kwargs=auth_kwargs, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_LIP_SYNC}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingLipSyncResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_LIP_SYNC}/{task_id}"), + response_model=KlingLipSyncResponse, estimated_duration=AVERAGE_DURATION_LIP_SYNC, - node_id=node_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -807,11 +727,7 @@ async def execute( ) -> IO.NodeOutput: model_mode, duration, model_name = MODE_TEXT2VIDEO[mode] return await execute_text2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, prompt=prompt, negative_prompt=negative_prompt, cfg_scale=cfg_scale, @@ -872,11 +788,7 @@ async def execute( camera_control: Optional[KlingCameraControl] = None, ) -> IO.NodeOutput: return await execute_text2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, model_name=KlingVideoGenModelName.kling_v1, cfg_scale=cfg_scale, model_mode=KlingVideoGenMode.std, @@ -944,11 +856,7 @@ async def execute( end_frame: Optional[torch.Tensor] = None, ) -> IO.NodeOutput: return await execute_image2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, start_frame=start_frame, prompt=prompt, negative_prompt=negative_prompt, @@ -1017,11 +925,7 @@ async def execute( camera_control: KlingCameraControl, ) -> IO.NodeOutput: return await execute_image2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, model_name=KlingVideoGenModelName.kling_v1_5, start_frame=start_frame, cfg_scale=cfg_scale, @@ -1097,11 +1001,7 @@ async def execute( ) -> IO.NodeOutput: mode, duration, model_name = MODE_START_END_FRAME[mode] return await execute_image2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, prompt=prompt, negative_prompt=negative_prompt, model_name=model_name, @@ -1162,41 +1062,27 @@ async def execute( video_id: str, ) -> IO.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_VIDEO_EXTEND, - method=HttpMethod.POST, - request_model=KlingVideoExtendRequest, - response_model=KlingVideoExtendResponse, - ), - request=KlingVideoExtendRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_VIDEO_EXTEND, method="POST"), + response_model=KlingVideoExtendResponse, + data=KlingVideoExtendRequest( prompt=prompt if prompt else None, negative_prompt=negative_prompt if negative_prompt else None, cfg_scale=cfg_scale, video_id=video_id, ), - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth, - ApiEndpoint( - path=f"{PATH_VIDEO_EXTEND}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingVideoExtendResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_VIDEO_EXTEND}/{task_id}"), + response_model=KlingVideoExtendResponse, estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND, - node_id=cls.hidden.unique_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -1259,11 +1145,7 @@ async def execute( duration: KlingVideoGenDuration, ) -> IO.NodeOutput: video, _, duration = await execute_video_effect( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, dual_character=True, effect_scene=effect_scene, model_name=model_name, @@ -1324,11 +1206,7 @@ async def execute( return IO.NodeOutput( *( await execute_video_effect( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, dual_character=False, effect_scene=effect_scene, model_name=model_name, @@ -1379,11 +1257,7 @@ async def execute( voice_language: str, ) -> IO.NodeOutput: return await execute_lipsync( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, video=video, audio=audio, voice_language=voice_language, @@ -1445,11 +1319,7 @@ async def execute( ) -> IO.NodeOutput: voice_id, voice_language = VOICES_CONFIG[voice] return await execute_lipsync( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, video=video, text=text, voice_language=voice_language, @@ -1496,40 +1366,26 @@ async def execute( cloth_image: torch.Tensor, model_name: KlingVirtualTryOnModelName, ) -> IO.NodeOutput: - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_VIRTUAL_TRY_ON, - method=HttpMethod.POST, - request_model=KlingVirtualTryOnRequest, - response_model=KlingVirtualTryOnResponse, - ), - request=KlingVirtualTryOnRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_VIRTUAL_TRY_ON, method="POST"), + response_model=KlingVirtualTryOnResponse, + data=KlingVirtualTryOnRequest( human_image=tensor_to_base64_string(human_image), cloth_image=tensor_to_base64_string(cloth_image), model_name=model_name, ), - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth, - ApiEndpoint( - path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingVirtualTryOnResponse, - ), - result_url_extractor=get_images_urls_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}"), + response_model=KlingVirtualTryOnResponse, estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON, - node_id=cls.hidden.unique_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_image_result_response(final_response) @@ -1625,18 +1481,11 @@ async def execute( else: image = tensor_to_base64_string(image) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_GENERATIONS, - method=HttpMethod.POST, - request_model=KlingImageGenerationsRequest, - response_model=KlingImageGenerationsResponse, - ), - request=KlingImageGenerationsRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_IMAGE_GENERATIONS, method="POST"), + response_model=KlingImageGenerationsResponse, + data=KlingImageGenerationsRequest( model_name=model_name, prompt=prompt, negative_prompt=negative_prompt, @@ -1647,24 +1496,17 @@ async def execute( n=n, aspect_ratio=aspect_ratio, ), - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth, - ApiEndpoint( - path=f"{PATH_IMAGE_GENERATIONS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingImageGenerationsResponse, - ), - result_url_extractor=get_images_urls_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_IMAGE_GENERATIONS}/{task_id}"), + response_model=KlingImageGenerationsResponse, estimated_duration=AVERAGE_DURATION_IMAGE_GEN, - node_id=cls.hidden.unique_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_image_result_response(final_response) diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 610d95a77b9d..e74441e5ef5f 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -35,9 +35,9 @@ from comfy_api_nodes.apinode_utils import ( upload_images_to_comfyapi, process_image_response, - validate_string, ) from server import PromptServer +from comfy_api_nodes.util import validate_string import aiohttp import torch diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index 23be1ae65ad8..e3722e79b715 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -24,8 +24,8 @@ from comfy_api_nodes.apinode_utils import ( download_url_to_bytesio, upload_images_to_comfyapi, - validate_string, ) +from comfy_api_nodes.util import validate_string from server import PromptServer diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 7566188dd86c..7c31d95b300a 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -1,35 +1,31 @@ import logging -from typing import Any, Callable, Optional, TypeVar +from typing import Optional + import torch from typing_extensions import override -from comfy_api_nodes.util.validation_utils import validate_image_dimensions +from comfy_api.input import VideoInput +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis import ( - MoonvalleyTextToVideoRequest, + MoonvalleyPromptResponse, MoonvalleyTextToVideoInferenceParams, + MoonvalleyTextToVideoRequest, MoonvalleyVideoToVideoInferenceParams, MoonvalleyVideoToVideoRequest, - MoonvalleyPromptResponse, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( download_url_to_video_output, + poll_op, + sync_op, + trim_video, upload_images_to_comfyapi, upload_video_to_comfyapi, validate_container_format_is_mp4, + validate_image_dimensions, + validate_string, ) -from comfy_api.input import VideoInput -from comfy_api.latest import ComfyExtension, InputImpl, IO -import av -import io - API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads" API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts" API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video" @@ -51,13 +47,6 @@ MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000 -R = TypeVar("R") - - -class MoonvalleyApiError(Exception): - """Base exception for Moonvalley API errors.""" - - pass def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool: @@ -69,64 +58,7 @@ def validate_task_creation_response(response) -> None: if not is_valid_task_creation_response(response): error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}" logging.error(error_msg) - raise MoonvalleyApiError(error_msg) - - -def get_video_from_response(response): - video = response.output_url - logging.info( - "Moonvalley Marey API: Task %s succeeded. Video URL: %s", response.id, video - ) - return video - - -def get_video_url_from_response(response) -> Optional[str]: - """Returns the first video url from the Moonvalley video generation task result. - Will not raise an error if the response is not valid. - """ - if response: - return str(get_video_from_response(response)) - else: - return None - - -async def poll_until_finished( - auth_kwargs: dict[str, str], - api_endpoint: ApiEndpoint[Any, R], - result_url_extractor: Optional[Callable[[R], str]] = None, - node_id: Optional[str] = None, -) -> R: - """Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response.""" - return await PollingOperation( - poll_endpoint=api_endpoint, - completed_statuses=[ - "completed", - ], - max_poll_attempts=240, # 64 minutes with 16s interval - poll_interval=16.0, - failed_statuses=["error"], - status_extractor=lambda response: ( - response.status if response and response.status else None - ), - auth_kwargs=auth_kwargs, - result_url_extractor=result_url_extractor, - node_id=node_id, - ).execute() - - -def validate_prompts( - prompt: str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH -): - """Verifies that the prompt isn't empty and that neither prompt is too long.""" - if not prompt: - raise ValueError("Positive prompt is empty") - if len(prompt) > max_length: - raise ValueError(f"Positive prompt is too long: {len(prompt)} characters") - if negative_prompt and len(negative_prompt) > max_length: - raise ValueError( - f"Negative prompt is too long: {len(negative_prompt)} characters" - ) - return True + raise RuntimeError(error_msg) def validate_video_to_video_input(video: VideoInput) -> VideoInput: @@ -170,12 +102,8 @@ def _validate_video_dimensions(width: int, height: int) -> None: } if (width, height) not in supported_resolutions: - supported_list = ", ".join( - [f"{w}x{h}" for w, h in sorted(supported_resolutions)] - ) - raise ValueError( - f"Resolution {width}x{height} not supported. Supported: {supported_list}" - ) + supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)]) + raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}") def _validate_and_trim_duration(video: VideoInput) -> VideoInput: @@ -188,7 +116,7 @@ def _validate_and_trim_duration(video: VideoInput) -> VideoInput: def _validate_minimum_duration(duration: float) -> None: """Ensures video is at least 5 seconds long.""" if duration < 5: - raise MoonvalleyApiError("Input video must be at least 5 seconds long.") + raise ValueError("Input video must be at least 5 seconds long.") def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput: @@ -198,123 +126,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput: return video -def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: - """ - Returns a new VideoInput object trimmed from the beginning to the specified duration, - using av to avoid loading entire video into memory. - - Args: - video: Input video to trim - duration_sec: Duration in seconds to keep from the beginning - - Returns: - VideoFromFile object that owns the output buffer - """ - output_buffer = io.BytesIO() - - input_container = None - output_container = None - - try: - # Get the stream source - this avoids loading entire video into memory - # when the source is already a file path - input_source = video.get_stream_source() - - # Open containers - input_container = av.open(input_source, mode="r") - output_container = av.open(output_buffer, mode="w", format="mp4") - - # Set up output streams for re-encoding - video_stream = None - audio_stream = None - - for stream in input_container.streams: - logging.info("Found stream: type=%s, class=%s", stream.type, type(stream)) - if isinstance(stream, av.VideoStream): - # Create output video stream with same parameters - video_stream = output_container.add_stream( - "h264", rate=stream.average_rate - ) - video_stream.width = stream.width - video_stream.height = stream.height - video_stream.pix_fmt = "yuv420p" - logging.info( - "Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate - ) - elif isinstance(stream, av.AudioStream): - # Create output audio stream with same parameters - audio_stream = output_container.add_stream( - "aac", rate=stream.sample_rate - ) - audio_stream.sample_rate = stream.sample_rate - audio_stream.layout = stream.layout - logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels) - - # Calculate target frame count that's divisible by 16 - fps = input_container.streams.video[0].average_rate - estimated_frames = int(duration_sec * fps) - target_frames = ( - estimated_frames // 16 - ) * 16 # Round down to nearest multiple of 16 - - if target_frames == 0: - raise ValueError("Video too short: need at least 16 frames for Moonvalley") - - frame_count = 0 - audio_frame_count = 0 - - # Decode and re-encode video frames - if video_stream: - for frame in input_container.decode(video=0): - if frame_count >= target_frames: - break - - # Re-encode frame - for packet in video_stream.encode(frame): - output_container.mux(packet) - frame_count += 1 - - # Flush encoder - for packet in video_stream.encode(): - output_container.mux(packet) - - logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames) - - # Decode and re-encode audio frames - if audio_stream: - input_container.seek(0) # Reset to beginning for audio - for frame in input_container.decode(audio=0): - if frame.time >= duration_sec: - break - - # Re-encode frame - for packet in audio_stream.encode(frame): - output_container.mux(packet) - audio_frame_count += 1 - - # Flush encoder - for packet in audio_stream.encode(): - output_container.mux(packet) - - logging.info("Encoded %s audio frames", audio_frame_count) - - # Close containers - output_container.close() - input_container.close() - - # Return as VideoFromFile using the buffer - output_buffer.seek(0) - return InputImpl.VideoFromFile(output_buffer) - - except Exception as e: - # Clean up on error - if input_container is not None: - input_container.close() - if output_container is not None: - output_container.close() - raise RuntimeError(f"Failed to trim video: {str(e)}") from e - - def parse_width_height_from_res(resolution: str): # Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict res_map = { @@ -338,19 +149,14 @@ def parse_control_parameter(value): return control_map.get(value, control_map["Motion Transfer"]) -async def get_response( - task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None -) -> MoonvalleyPromptResponse: - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{API_PROMPTS_ENDPOINT}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MoonvalleyPromptResponse, - ), - result_url_extractor=get_video_url_from_response, - node_id=node_id, +async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse: + return await poll_op( + cls, + ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"), + response_model=MoonvalleyPromptResponse, + status_extractor=lambda r: (r.status if r and r.status else None), + poll_interval=16.0, + max_poll_attempts=240, ) @@ -444,14 +250,10 @@ async def execute( steps: int, ) -> IO.NodeOutput: validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH) - validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) width_height = parse_width_height_from_res(resolution) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - inference_params = MoonvalleyTextToVideoInferenceParams( negative_prompt=negative_prompt, steps=steps, @@ -464,33 +266,17 @@ async def execute( # Get MIME type from tensor - assuming PNG format for image tensors mime_type = "image/png" - - image_url = ( - await upload_images_to_comfyapi( - image, max_images=1, auth_kwargs=auth, mime_type=mime_type - ) - )[0] - - request = MoonvalleyTextToVideoRequest( - image_url=image_url, prompt_text=prompt, inference_params=inference_params - ) - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=API_IMG2VIDEO_ENDPOINT, - method=HttpMethod.POST, - request_model=MoonvalleyTextToVideoRequest, - response_model=MoonvalleyPromptResponse, + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0] + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"), + response_model=MoonvalleyPromptResponse, + data=MoonvalleyTextToVideoRequest( + image_url=image_url, prompt_text=prompt, inference_params=inference_params ), - request=request, - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) - task_id = task_creation_response.id - - final_response = await get_response( - task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id - ) + final_response = await get_response(cls, task_creation_response.id) video = await download_url_to_video_output(final_response.output_url) return IO.NodeOutput(video) @@ -582,15 +368,10 @@ async def execute( steps=33, prompt_adherence=4.5, ) -> IO.NodeOutput: - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - validated_video = validate_video_to_video_input(video) - video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth) - - validate_prompts(prompt, negative_prompt) + video_url = await upload_video_to_comfyapi(cls, validated_video) + validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) # Only include motion_intensity for Motion Transfer control_params = {} @@ -605,35 +386,20 @@ async def execute( guidance_scale=prompt_adherence, ) - control = parse_control_parameter(control_type) - - request = MoonvalleyVideoToVideoRequest( - control_type=control, - video_url=video_url, - prompt_text=prompt, - inference_params=inference_params, - ) - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=API_VIDEO2VIDEO_ENDPOINT, - method=HttpMethod.POST, - request_model=MoonvalleyVideoToVideoRequest, - response_model=MoonvalleyPromptResponse, + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"), + response_model=MoonvalleyPromptResponse, + data=MoonvalleyVideoToVideoRequest( + control_type=parse_control_parameter(control_type), + video_url=video_url, + prompt_text=prompt, + inference_params=inference_params, ), - request=request, - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) - task_id = task_creation_response.id - - final_response = await get_response( - task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id - ) - - video = await download_url_to_video_output(final_response.output_url) - return IO.NodeOutput(video) + final_response = await get_response(cls, task_creation_response.id) + return IO.NodeOutput(await download_url_to_video_output(final_response.output_url)) class MoonvalleyTxt2VideoNode(IO.ComfyNode): @@ -720,14 +486,10 @@ async def execute( seed: int, steps: int, ) -> IO.NodeOutput: - validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) width_height = parse_width_height_from_res(resolution) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - inference_params = MoonvalleyTextToVideoInferenceParams( negative_prompt=negative_prompt, steps=steps, @@ -737,30 +499,16 @@ async def execute( width=width_height["width"], height=width_height["height"], ) - request = MoonvalleyTextToVideoRequest( - prompt_text=prompt, inference_params=inference_params - ) - init_op = SynchronousOperation( - endpoint=ApiEndpoint( - path=API_TXT2VIDEO_ENDPOINT, - method=HttpMethod.POST, - request_model=MoonvalleyTextToVideoRequest, - response_model=MoonvalleyPromptResponse, - ), - request=request, - auth_kwargs=auth, + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"), + response_model=MoonvalleyPromptResponse, + data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params), ) - task_creation_response = await init_op.execute() validate_task_creation_response(task_creation_response) - task_id = task_creation_response.id - - final_response = await get_response( - task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id - ) - - video = await download_url_to_video_output(final_response.output_url) - return IO.NodeOutput(video) + final_response = await get_response(cls, task_creation_response.id) + return IO.NodeOutput(await download_url_to_video_output(final_response.output_url)) class MoonvalleyExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index e3b81de7599e..c467e840cf65 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -43,13 +43,11 @@ ) from comfy_api_nodes.apinode_utils import ( - downscale_image_tensor, validate_and_cast_response, - validate_string, - tensor_to_base64_string, text_filepath_to_data_uri, ) from comfy_api_nodes.mapper_utils import model_field_to_node_input +from comfy_api_nodes.util import downscale_image_tensor, validate_string, tensor_to_base64_string RESPONSES_ENDPOINT = "/proxy/openai/v1/responses" diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py index 27cb0067b008..5bb406a3bb77 100644 --- a/comfy_api_nodes/nodes_pika.py +++ b/comfy_api_nodes/nodes_pika.py @@ -14,11 +14,6 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput -from comfy_api_nodes.apinode_utils import ( - download_url_to_video_output, - tensor_to_bytesio, - validate_string, -) from comfy_api_nodes.apis import pika_defs from comfy_api_nodes.apis.client import ( ApiEndpoint, @@ -27,6 +22,7 @@ PollingOperation, SynchronousOperation, ) +from comfy_api_nodes.util import validate_string, download_url_to_video_output, tensor_to_bytesio R = TypeVar("R") diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index 438a7f80b1e9..b2b841be88ff 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -24,10 +24,7 @@ PollingOperation, EmptyRequest, ) -from comfy_api_nodes.apinode_utils import ( - tensor_to_bytesio, - validate_string, -) +from comfy_api_nodes.util import validate_string, tensor_to_bytesio from comfy_api.input_impl import VideoFromFile from comfy_api.latest import ComfyExtension, IO @@ -50,7 +47,6 @@ def get_video_url_from_response( async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): # first, upload image to Pixverse and get image id to use in actual generation call - files = {"image": tensor_to_bytesio(image)} operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/pixverse/image/upload", @@ -59,16 +55,14 @@ async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): response_model=PixverseImageUploadResponse, ), request=EmptyRequest(), - files=files, + files={"image": tensor_to_bytesio(image)}, content_type="multipart/form-data", auth_kwargs=auth_kwargs, ) response_upload: PixverseImageUploadResponse = await operation.execute() if response_upload.Resp is None: - raise Exception( - f"PixVerse image upload request failed: '{response_upload.ErrMsg}'" - ) + raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'") return response_upload.Resp.img_id @@ -95,7 +89,6 @@ def execute(cls, template: str) -> IO.NodeOutput: template_id = pixverse_templates.get(template, None) if template_id is None: raise Exception(f"Template '{template}' is not recognized.") - # just return the integer return IO.NodeOutput(template_id) diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index 8beed5675c17..8ee7e55c4e71 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -24,12 +24,10 @@ EmptyRequest, ) from comfy_api_nodes.apinode_utils import ( - bytesio_to_image_tensor, download_url_to_bytesio, - tensor_to_bytesio, resize_mask_to_image, - validate_string, ) +from comfy_api_nodes.util import validate_string, tensor_to_bytesio, bytesio_to_image_tensor from server import PromptServer import torch diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index eb03a897dece..0543d1d0e27c 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -11,7 +11,7 @@ """ -from typing import Union, Optional, Any +from typing import Union, Optional from typing_extensions import override from enum import Enum @@ -21,7 +21,6 @@ RunwayImageToVideoRequest, RunwayImageToVideoResponse, RunwayTaskStatusResponse as TaskStatusResponse, - RunwayTaskStatusEnum as TaskStatus, RunwayModelEnum as Model, RunwayDurationEnum as Duration, RunwayAspectRatioEnum as AspectRatio, @@ -33,23 +32,20 @@ ReferenceImage, RunwayTextToImageAspectRatioEnum, ) -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( - upload_images_to_comfyapi, - download_url_to_video_output, +from comfy_api_nodes.util import ( image_tensor_pair_to_batch, validate_string, + validate_image_dimensions, + validate_image_aspect_ratio, + upload_images_to_comfyapi, + download_url_to_video_output, download_url_to_image_tensor, + ApiEndpoint, + sync_op, + poll_op, ) from comfy_api.input_impl import VideoFromFile from comfy_api.latest import ComfyExtension, IO -from comfy_api_nodes.util.validation_utils import validate_image_dimensions, validate_image_aspect_ratio PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video" PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image" @@ -91,31 +87,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N return None -async def poll_until_finished( - auth_kwargs: dict[str, str], - api_endpoint: ApiEndpoint[Any, TaskStatusResponse], - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> TaskStatusResponse: - """Polls the Runway API endpoint until the task reaches a terminal state, then returns the response.""" - return await PollingOperation( - poll_endpoint=api_endpoint, - completed_statuses=[ - TaskStatus.SUCCEEDED.value, - ], - failed_statuses=[ - TaskStatus.FAILED.value, - TaskStatus.CANCELLED.value, - ], - status_extractor=lambda response: response.status.value, - auth_kwargs=auth_kwargs, - result_url_extractor=get_video_url_from_task_status, - estimated_duration=estimated_duration, - node_id=node_id, - progress_extractor=extract_progress_from_task_status, - ).execute() - - def extract_progress_from_task_status( response: TaskStatusResponse, ) -> Union[float, None]: @@ -132,42 +103,32 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N async def get_response( - task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None, estimated_duration: Optional[int] = None + cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None ) -> TaskStatusResponse: """Poll the task status until it is finished then get the response.""" - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_GET_TASK_STATUS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), + return await poll_op( + cls, + ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: r.status.value, estimated_duration=estimated_duration, - node_id=node_id, + progress_extractor=extract_progress_from_task_status, ) async def generate_video( + cls: type[IO.ComfyNode], request: RunwayImageToVideoRequest, - auth_kwargs: dict[str, str], - node_id: Optional[str] = None, estimated_duration: Optional[int] = None, ) -> VideoFromFile: - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_TO_VIDEO, - method=HttpMethod.POST, - request_model=RunwayImageToVideoRequest, - response_model=RunwayImageToVideoResponse, - ), - request=request, - auth_kwargs=auth_kwargs, + initial_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), + response_model=RunwayImageToVideoResponse, + data=request, ) - initial_response = await initial_operation.execute() - - final_response = await get_response(initial_response.id, auth_kwargs, node_id, estimated_duration) + final_response = await get_response(cls, initial_response.id, estimated_duration) if not final_response.output: raise RunwayApiError("Runway task succeeded but no video data found in response.") @@ -184,9 +145,9 @@ def define_schema(cls): display_name="Runway Image to Video (Gen3a Turbo)", category="api node/video/Runway", description="Generate a video from a single starting frame using Gen3a Turbo model. " - "Before diving in, review these best practices to ensure that " - "your input selections will set your generation up for success: " - "https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.", + "Before diving in, review these best practices to ensure that " + "your input selections will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.", inputs=[ IO.String.Input( "prompt", @@ -241,20 +202,16 @@ async def execute( validate_image_dimensions(start_frame, max_width=7999, max_height=7999) validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - download_urls = await upload_images_to_comfyapi( + cls, start_frame, max_images=1, mime_type="image/png", - auth_kwargs=auth_kwargs, ) return IO.NodeOutput( await generate_video( + cls, RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -262,15 +219,9 @@ async def execute( duration=Duration(duration), ratio=AspectRatio(ratio), promptImage=RunwayPromptImageObject( - root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ) - ] + root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")] ), ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, ) ) @@ -284,9 +235,9 @@ def define_schema(cls): display_name="Runway Image to Video (Gen4 Turbo)", category="api node/video/Runway", description="Generate a video from a single starting frame using Gen4 Turbo model. " - "Before diving in, review these best practices to ensure that " - "your input selections will set your generation up for success: " - "https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.", + "Before diving in, review these best practices to ensure that " + "your input selections will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.", inputs=[ IO.String.Input( "prompt", @@ -341,20 +292,16 @@ async def execute( validate_image_dimensions(start_frame, max_width=7999, max_height=7999) validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - download_urls = await upload_images_to_comfyapi( + cls, start_frame, max_images=1, mime_type="image/png", - auth_kwargs=auth_kwargs, ) return IO.NodeOutput( await generate_video( + cls, RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -362,15 +309,9 @@ async def execute( duration=Duration(duration), ratio=AspectRatio(ratio), promptImage=RunwayPromptImageObject( - root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ) - ] + root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")] ), ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_FLF_SECONDS, ) ) @@ -385,12 +326,12 @@ def define_schema(cls): display_name="Runway First-Last-Frame to Video", category="api node/video/Runway", description="Upload first and last keyframes, draft a prompt, and generate a video. " - "More complex transitions, such as cases where the Last frame is completely different " - "from the First frame, may benefit from the longer 10s duration. " - "This would give the generation more time to smoothly transition between the two inputs. " - "Before diving in, review these best practices to ensure that your input selections " - "will set your generation up for success: " - "https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.", + "More complex transitions, such as cases where the Last frame is completely different " + "from the First frame, may benefit from the longer 10s duration. " + "This would give the generation more time to smoothly transition between the two inputs. " + "Before diving in, review these best practices to ensure that your input selections " + "will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.", inputs=[ IO.String.Input( "prompt", @@ -452,23 +393,19 @@ async def execute( validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame) download_urls = await upload_images_to_comfyapi( + cls, stacked_input_images, max_images=2, mime_type="image/png", - auth_kwargs=auth_kwargs, ) if len(download_urls) != 2: raise RunwayApiError("Failed to upload one or more images to comfy api.") return IO.NodeOutput( await generate_video( + cls, RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -477,17 +414,11 @@ async def execute( ratio=AspectRatio(ratio), promptImage=RunwayPromptImageObject( root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ), - RunwayPromptImageDetailedObject( - uri=str(download_urls[1]), position="last" - ), + RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first"), + RunwayPromptImageDetailedObject(uri=str(download_urls[1]), position="last"), ] ), ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_FLF_SECONDS, ) ) @@ -502,7 +433,7 @@ def define_schema(cls): display_name="Runway Text to Image", category="api node/image/Runway", description="Generate an image from a text prompt using Runway's Gen 4 model. " - "You can also include reference image to guide the generation.", + "You can also include reference image to guide the generation.", inputs=[ IO.String.Input( "prompt", @@ -540,49 +471,34 @@ async def execute( ) -> IO.NodeOutput: validate_string(prompt, min_length=1) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - # Prepare reference images if provided reference_images = None if reference_image is not None: validate_image_dimensions(reference_image, max_width=7999, max_height=7999) validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0) download_urls = await upload_images_to_comfyapi( + cls, reference_image, max_images=1, mime_type="image/png", - auth_kwargs=auth_kwargs, ) reference_images = [ReferenceImage(uri=str(download_urls[0]))] - request = RunwayTextToImageRequest( - promptText=prompt, - model=Model4.gen4_image, - ratio=ratio, - referenceImages=reference_images, - ) - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_TEXT_TO_IMAGE, - method=HttpMethod.POST, - request_model=RunwayTextToImageRequest, - response_model=RunwayTextToImageResponse, + initial_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=PATH_TEXT_TO_IMAGE, method="POST"), + response_model=RunwayTextToImageResponse, + data=RunwayTextToImageRequest( + promptText=prompt, + model=Model4.gen4_image, + ratio=ratio, + referenceImages=reference_images, ), - request=request, - auth_kwargs=auth_kwargs, ) - initial_response = await initial_operation.execute() - - # Poll for completion final_response = await get_response( + cls, initial_response.id, - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_T2I_SECONDS, ) if not final_response.output: @@ -601,5 +517,6 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]: RunwayTextToImageNode, ] + async def comfy_entrypoint() -> RunwayExtension: return RunwayExtension() diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index efc95486977e..92b225d4043d 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -1,23 +1,20 @@ from typing import Optional -from typing_extensions import override import torch from pydantic import BaseModel, Field -from comfy_api.latest import ComfyExtension, IO -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.util.validation_utils import get_number_of_images +from typing_extensions import override -from comfy_api_nodes.apinode_utils import ( +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.util import ( + ApiEndpoint, download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, tensor_to_bytesio, ) + class Sora2GenerationRequest(BaseModel): prompt: str = Field(...) model: str = Field(...) @@ -80,7 +77,7 @@ def define_schema(cls): control_after_generate=True, optional=True, tooltip="Seed to determine if node should re-run; " - "actual results are nondeterministic regardless of seed.", + "actual results are nondeterministic regardless of seed.", ), ], outputs=[ @@ -111,55 +108,34 @@ async def execute( if get_number_of_images(image) != 1: raise ValueError("Currently only one input image is supported.") files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")} - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - payload = Sora2GenerationRequest( - model=model, - prompt=prompt, - seconds=str(duration), - size=size, - ) - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/openai/v1/videos", - method=HttpMethod.POST, - request_model=Sora2GenerationRequest, - response_model=Sora2GenerationResponse + initial_response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/openai/v1/videos", method="POST"), + data=Sora2GenerationRequest( + model=model, + prompt=prompt, + seconds=str(duration), + size=size, ), - request=payload, files=files_input, - auth_kwargs=auth, + response_model=Sora2GenerationResponse, content_type="multipart/form-data", ) - initial_response = await initial_operation.execute() if initial_response.error: - raise Exception(initial_response.error.message) + raise Exception(initial_response.error["message"]) model_time_multiplier = 1 if model == "sora-2" else 2 - poll_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/openai/v1/videos/{initial_response.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=Sora2GenerationResponse - ), - completed_statuses=["completed"], - failed_statuses=["failed"], + await poll_op( + cls, + poll_endpoint=ApiEndpoint(path=f"/proxy/openai/v1/videos/{initial_response.id}"), + response_model=Sora2GenerationResponse, status_extractor=lambda x: x.status, - auth_kwargs=auth, poll_interval=8.0, max_poll_attempts=160, - node_id=cls.hidden.unique_id, - estimated_duration=45 * (duration / 4) * model_time_multiplier, + estimated_duration=int(45 * (duration / 4) * model_time_multiplier), ) - await poll_operation.execute() return IO.NodeOutput( - await download_url_to_video_output( - f"/proxy/openai/v1/videos/{initial_response.id}/content", - auth_kwargs=auth, - ) + await download_url_to_video_output(f"/proxy/openai/v1/videos/{initial_response.id}/content", cls=cls), ) diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 8af03cfd1247..783666ddf5fa 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -27,14 +27,14 @@ PollingOperation, EmptyRequest, ) -from comfy_api_nodes.apinode_utils import ( +from comfy_api_nodes.util import ( + validate_audio_duration, + validate_string, + audio_input_to_mp3, bytesio_to_image_tensor, tensor_to_bytesio, - validate_string, audio_bytes_to_audio_input, - audio_input_to_mp3, ) -from comfy_api_nodes.util.validation_utils import validate_audio_duration import torch import base64 diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index daeaa823e44e..d37e9e9b410a 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -1,28 +1,21 @@ -import logging import base64 -import aiohttp -import torch from io import BytesIO -from typing import Optional + from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO from comfy_api.input_impl.video_types import VideoFromFile -from comfy_api_nodes.apis import ( - VeoGenVidRequest, - VeoGenVidResponse, +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.apis.veo_api import ( VeoGenVidPollRequest, VeoGenVidPollResponse, + VeoGenVidRequest, + VeoGenVidResponse, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, -) - -from comfy_api_nodes.apinode_utils import ( - downscale_image_tensor, + download_url_to_video_output, + poll_op, + sync_op, tensor_to_base64_string, ) @@ -35,28 +28,6 @@ "veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001", } -def convert_image_to_base64(image: torch.Tensor): - if image is None: - return None - - scaled_image = downscale_image_tensor(image, total_pixels=2048*2048) - return tensor_to_base64_string(scaled_image) - - -def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]: - if ( - poll_response.response - and hasattr(poll_response.response, "videos") - and poll_response.response.videos - and len(poll_response.response.videos) > 0 - ): - video = poll_response.response.videos[0] - else: - return None - if hasattr(video, "gcsUri") and video.gcsUri: - return str(video.gcsUri) - return None - class VeoVideoGenerationNode(IO.ComfyNode): """ @@ -169,18 +140,13 @@ async def execute( # Prepare the instances for the request instances = [] - instance = { - "prompt": prompt - } + instance = {"prompt": prompt} # Add image if provided if image is not None: - image_base64 = convert_image_to_base64(image) + image_base64 = tensor_to_base64_string(image) if image_base64: - instance["image"] = { - "bytesBase64Encoded": image_base64, - "mimeType": "image/png" - } + instance["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"} instances.append(instance) @@ -198,119 +164,77 @@ async def execute( if seed > 0: parameters["seed"] = seed # Only add generateAudio for Veo 3 models - if "veo-3.0" in model: + if model.find("veo-2.0") == -1: parameters["generateAudio"] = generate_audio - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - # Initial request to start video generation - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=f"/proxy/veo/{model}/generate", - method=HttpMethod.POST, - request_model=VeoGenVidRequest, - response_model=VeoGenVidResponse - ), - request=VeoGenVidRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"), + response_model=VeoGenVidResponse, + data=VeoGenVidRequest( instances=instances, - parameters=parameters + parameters=parameters, ), - auth_kwargs=auth, ) - initial_response = await initial_operation.execute() - operation_name = initial_response.name - - logging.info("Veo generation started with operation name: %s", operation_name) - - # Define status extractor function def status_extractor(response): # Only return "completed" if the operation is done, regardless of success or failure # We'll check for errors after polling completes return "completed" if response.done else "pending" - # Define progress extractor function - def progress_extractor(response): - # Could be enhanced if the API provides progress information - return None - - # Define the polling operation - poll_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/veo/{model}/poll", - method=HttpMethod.POST, - request_model=VeoGenVidPollRequest, - response_model=VeoGenVidPollResponse - ), - completed_statuses=["completed"], - failed_statuses=[], # No failed statuses, we'll handle errors after polling + poll_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"), + response_model=VeoGenVidPollResponse, status_extractor=status_extractor, - progress_extractor=progress_extractor, - request=VeoGenVidPollRequest( - operationName=operation_name + data=VeoGenVidPollRequest( + operationName=initial_response.name, ), - auth_kwargs=auth, poll_interval=5.0, - result_url_extractor=get_video_url_from_response, - node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_VIDEO_GEN, ) - # Execute the polling operation - poll_response = await poll_operation.execute() - # Now check for errors in the final response # Check for error in poll response - if hasattr(poll_response, 'error') and poll_response.error: - error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})" - logging.error(error_message) - raise Exception(error_message) + if poll_response.error: + raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})") # Check for RAI filtered content - if (hasattr(poll_response.response, 'raiMediaFilteredCount') and - poll_response.response.raiMediaFilteredCount > 0): + if ( + hasattr(poll_response.response, "raiMediaFilteredCount") + and poll_response.response.raiMediaFilteredCount > 0 + ): # Extract reason message if available - if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and - poll_response.response.raiMediaFilteredReasons): + if ( + hasattr(poll_response.response, "raiMediaFilteredReasons") + and poll_response.response.raiMediaFilteredReasons + ): reason = poll_response.response.raiMediaFilteredReasons[0] error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)" else: error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)" - logging.error(error_message) raise Exception(error_message) # Extract video data - if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0: + if ( + poll_response.response + and hasattr(poll_response.response, "videos") + and poll_response.response.videos + and len(poll_response.response.videos) > 0 + ): video = poll_response.response.videos[0] # Check if video is provided as base64 or URL - if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded: - # Decode base64 string to bytes - video_data = base64.b64decode(video.bytesBase64Encoded) - elif hasattr(video, 'gcsUri') and video.gcsUri: - # Download from URL - async with aiohttp.ClientSession() as session: - async with session.get(video.gcsUri) as video_response: - video_data = await video_response.content.read() - else: - raise Exception("Video returned but no data or URL was provided") - else: - raise Exception("Video generation completed but no video was returned") + if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded: + return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) - if not video_data: - raise Exception("No video data was returned") + if hasattr(video, "gcsUri") and video.gcsUri: + return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) - logging.info("Video generation completed successfully") - - # Convert video data to BytesIO object - video_io = BytesIO(video_data) - - # Return VideoFromFile object - return IO.NodeOutput(VideoFromFile(video_io)) + raise Exception("Video returned but no data or URL was provided") + raise Exception("Video generation completed but no video was returned") class Veo3VideoGenerationNode(VeoVideoGenerationNode): @@ -394,7 +318,10 @@ def define_schema(cls): IO.Combo.Input( "model", options=[ - "veo-3.1-generate", "veo-3.1-fast-generate", "veo-3.0-generate-001", "veo-3.0-fast-generate-001" + "veo-3.1-generate", + "veo-3.1-fast-generate", + "veo-3.0-generate-001", + "veo-3.0-fast-generate-001", ], default="veo-3.0-generate-001", tooltip="Veo 3 model to use for video generation", @@ -427,5 +354,6 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]: Veo3VideoGenerationNode, ] + async def comfy_entrypoint() -> VeoExtension: return VeoExtension() diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 639be4b2be66..0e0572f8c7c8 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -1,27 +1,23 @@ import logging from enum import Enum -from typing import Any, Callable, Optional, Literal, TypeVar -from typing_extensions import override +from typing import Literal, Optional, TypeVar import torch from pydantic import BaseModel, Field +from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO -from comfy_api_nodes.util.validation_utils import ( +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, + upload_images_to_comfyapi, validate_aspect_ratio_closeness, - validate_image_dimensions, validate_image_aspect_ratio_range, - get_number_of_images, -) -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, + validate_image_dimensions, ) -from comfy_api_nodes.apinode_utils import download_url_to_video_output, upload_images_to_comfyapi - VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video" @@ -31,8 +27,9 @@ R = TypeVar("R") + class VideoModelName(str, Enum): - vidu_q1 = 'viduq1' + vidu_q1 = "viduq1" class AspectRatio(str, Enum): @@ -63,17 +60,9 @@ class TaskCreationRequest(BaseModel): images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL") -class TaskStatus(str, Enum): - created = "created" - queueing = "queueing" - processing = "processing" - success = "success" - failed = "failed" - - class TaskCreationResponse(BaseModel): task_id: str = Field(...) - state: TaskStatus = Field(...) + state: str = Field(...) created_at: str = Field(...) code: Optional[int] = Field(None, description="Error code") @@ -85,32 +74,11 @@ class TaskResult(BaseModel): class TaskStatusResponse(BaseModel): - state: TaskStatus = Field(...) + state: str = Field(...) err_code: Optional[str] = Field(None) creations: list[TaskResult] = Field(..., description="Generated results") -async def poll_until_finished( - auth_kwargs: dict[str, str], - api_endpoint: ApiEndpoint[Any, R], - result_url_extractor: Optional[Callable[[R], str]] = None, - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> R: - return await PollingOperation( - poll_endpoint=api_endpoint, - completed_statuses=[TaskStatus.success.value], - failed_statuses=[TaskStatus.failed.value], - status_extractor=lambda response: response.state.value, - auth_kwargs=auth_kwargs, - result_url_extractor=result_url_extractor, - estimated_duration=estimated_duration, - node_id=node_id, - poll_interval=16.0, - max_poll_attempts=256, - ).execute() - - def get_video_url_from_response(response) -> Optional[str]: if response.creations: return response.creations[0].url @@ -127,37 +95,27 @@ def get_video_from_response(response) -> TaskResult: async def execute_task( + cls: type[IO.ComfyNode], vidu_endpoint: str, - auth_kwargs: Optional[dict[str, str]], payload: TaskCreationRequest, estimated_duration: int, - node_id: str, ) -> R: - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=vidu_endpoint, - method=HttpMethod.POST, - request_model=TaskCreationRequest, - response_model=TaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() - if response.state == TaskStatus.failed: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"), + response_model=TaskCreationResponse, + data=payload, + ) + if response.state == "failed": error_msg = f"Vidu request failed. Code: {response.code}" logging.error(error_msg) raise RuntimeError(error_msg) - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=VIDU_GET_GENERATION_STATUS % response.task_id, - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), - result_url_extractor=get_video_url_from_response, + return await poll_op( + cls, + ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id), + response_model=TaskStatusResponse, + status_extractor=lambda r: r.state.value, estimated_duration=estimated_duration, - node_id=node_id, ) @@ -258,11 +216,7 @@ async def execute( resolution=resolution, movement_amplitude=movement_amplitude, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - results = await execute_task(VIDU_TEXT_TO_VIDEO, auth, payload, 320, cls.hidden.unique_id) + results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) @@ -362,17 +316,13 @@ async def execute( resolution=resolution, movement_amplitude=movement_amplitude, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } payload.images = await upload_images_to_comfyapi( + cls, image, max_images=1, mime_type="image/png", - auth_kwargs=auth, ) - results = await execute_task(VIDU_IMAGE_TO_VIDEO, auth, payload, 120, cls.hidden.unique_id) + results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) @@ -484,17 +434,13 @@ async def execute( resolution=resolution, movement_amplitude=movement_amplitude, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } payload.images = await upload_images_to_comfyapi( + cls, images, max_images=7, mime_type="image/png", - auth_kwargs=auth, ) - results = await execute_task(VIDU_REFERENCE_VIDEO, auth, payload, 120, cls.hidden.unique_id) + results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) @@ -596,15 +542,11 @@ async def execute( resolution=resolution, movement_amplitude=movement_amplitude, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } payload.images = [ - (await upload_images_to_comfyapi(frame, max_images=1, mime_type="image/png", auth_kwargs=auth))[0] + (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] for frame in (first_frame, end_frame) ] - results = await execute_task(VIDU_START_END_VIDEO, auth, payload, 96, cls.hidden.unique_id) + results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) @@ -618,5 +560,6 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]: ViduStartEndToVideoNode, ] + async def comfy_entrypoint() -> ViduExtension: return ViduExtension() diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index b089bd907b25..2aab3c2ffb5c 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -1,28 +1,24 @@ import re -from typing import Optional, Type, Union -from typing_extensions import override +from typing import Optional import torch from pydantic import BaseModel, Field -from comfy_api.latest import ComfyExtension, Input, IO -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, - R, - T, -) -from comfy_api_nodes.util.validation_utils import get_number_of_images, validate_audio_duration +from typing_extensions import override -from comfy_api_nodes.apinode_utils import ( +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.util import ( + ApiEndpoint, + audio_to_base64_string, download_url_to_image_tensor, download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, tensor_to_base64_string, - audio_to_base64_string, + validate_audio_duration, ) + class Text2ImageInputField(BaseModel): prompt: str = Field(...) negative_prompt: Optional[str] = Field(None) @@ -146,53 +142,7 @@ class VideoTaskStatusResponse(BaseModel): request_id: str = Field(...) -RES_IN_PARENS = re.compile(r'\((\d+)\s*[x×]\s*(\d+)\)') - - -async def process_task( - auth_kwargs: dict[str, str], - url: str, - request_model: Type[T], - response_model: Type[R], - payload: Union[ - Text2ImageTaskCreationRequest, - Image2ImageTaskCreationRequest, - Text2VideoTaskCreationRequest, - Image2VideoTaskCreationRequest, - ], - node_id: str, - estimated_duration: int, - poll_interval: int, -) -> Type[R]: - initial_response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=url, - method=HttpMethod.POST, - request_model=request_model, - response_model=TaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() - - if not initial_response.output: - raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") - - return await PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=response_model, - ), - completed_statuses=["SUCCEEDED"], - failed_statuses=["FAILED", "CANCELED", "UNKNOWN"], - status_extractor=lambda x: x.output.task_status, - estimated_duration=estimated_duration, - poll_interval=poll_interval, - node_id=node_id, - auth_kwargs=auth_kwargs, - ).execute() +RES_IN_PARENS = re.compile(r"\((\d+)\s*[x×]\s*(\d+)\)") class WanTextToImageApi(IO.ComfyNode): @@ -259,7 +209,7 @@ def define_schema(cls): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the result.", + tooltip='Whether to add an "AI generated" watermark to the result.', optional=True, ), ], @@ -286,26 +236,28 @@ async def execute( prompt_extend: bool = True, watermark: bool = True, ): - payload = Text2ImageTaskCreationRequest( - model=model, - input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt), - parameters=Txt2ImageParametersField( - size=f"{width}*{height}", - seed=seed, - prompt_extend=prompt_extend, - watermark=watermark, + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Text2ImageTaskCreationRequest( + model=model, + input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt), + parameters=Txt2ImageParametersField( + size=f"{width}*{height}", + seed=seed, + prompt_extend=prompt_extend, + watermark=watermark, + ), ), ) - response = await process_task( - { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - "/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", - request_model=Text2ImageTaskCreationRequest, + if not initial_response.output: + raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), response_model=ImageTaskStatusResponse, - payload=payload, - node_id=cls.hidden.unique_id, + status_extractor=lambda x: x.output.task_status, estimated_duration=9, poll_interval=3, ) @@ -320,7 +272,7 @@ def define_schema(cls): display_name="Wan Image to Image", category="api node/image/Wan", description="Generates an image from one or two input images and a text prompt. " - "The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).", + "The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).", inputs=[ IO.Combo.Input( "model", @@ -376,7 +328,7 @@ def define_schema(cls): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the result.", + tooltip='Whether to add an "AI generated" watermark to the result.', optional=True, ), ], @@ -408,28 +360,30 @@ async def execute( raise ValueError(f"Expected 1 or 2 input images, got {n_images}.") images = [] for i in image: - images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096*4096)) - payload = Image2ImageTaskCreationRequest( - model=model, - input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images), - parameters=Image2ImageParametersField( - # size=f"{width}*{height}", - seed=seed, - watermark=watermark, + images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096)) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/image2image/image-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Image2ImageTaskCreationRequest( + model=model, + input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images), + parameters=Image2ImageParametersField( + # size=f"{width}*{height}", + seed=seed, + watermark=watermark, + ), ), ) - response = await process_task( - { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - "/proxy/wan/api/v1/services/aigc/image2image/image-synthesis", - request_model=Image2ImageTaskCreationRequest, + if not initial_response.output: + raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), response_model=ImageTaskStatusResponse, - payload=payload, - node_id=cls.hidden.unique_id, + status_extractor=lambda x: x.output.task_status, estimated_duration=42, - poll_interval=3, + poll_interval=4, ) return IO.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url))) @@ -523,7 +477,7 @@ def define_schema(cls): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the result.", + tooltip='Whether to add an "AI generated" watermark to the result.', optional=True, ), ], @@ -557,28 +511,31 @@ async def execute( if audio is not None: validate_audio_duration(audio, 3.0, 29.0) audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame") - payload = Text2VideoTaskCreationRequest( - model=model, - input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url), - parameters=Text2VideoParametersField( - size=f"{width}*{height}", - duration=duration, - seed=seed, - audio=generate_audio, - prompt_extend=prompt_extend, - watermark=watermark, + + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Text2VideoTaskCreationRequest( + model=model, + input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url), + parameters=Text2VideoParametersField( + size=f"{width}*{height}", + duration=duration, + seed=seed, + audio=generate_audio, + prompt_extend=prompt_extend, + watermark=watermark, + ), ), ) - response = await process_task( - { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - "/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", - request_model=Text2VideoTaskCreationRequest, + if not initial_response.output: + raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), response_model=VideoTaskStatusResponse, - payload=payload, - node_id=cls.hidden.unique_id, + status_extractor=lambda x: x.output.task_status, estimated_duration=120 * int(duration / 5), poll_interval=6, ) @@ -667,7 +624,7 @@ def define_schema(cls): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the result.", + tooltip='Whether to add an "AI generated" watermark to the result.', optional=True, ), ], @@ -699,35 +656,37 @@ async def execute( ): if get_number_of_images(image) != 1: raise ValueError("Exactly one input image is required.") - image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000*2000) + image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000) audio_url = None if audio is not None: validate_audio_duration(audio, 3.0, 29.0) audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame") - payload = Image2VideoTaskCreationRequest( - model=model, - input=Image2VideoInputField( - prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url - ), - parameters=Image2VideoParametersField( - resolution=resolution, - duration=duration, - seed=seed, - audio=generate_audio, - prompt_extend=prompt_extend, - watermark=watermark, + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Image2VideoTaskCreationRequest( + model=model, + input=Image2VideoInputField( + prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url + ), + parameters=Image2VideoParametersField( + resolution=resolution, + duration=duration, + seed=seed, + audio=generate_audio, + prompt_extend=prompt_extend, + watermark=watermark, + ), ), ) - response = await process_task( - { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - "/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", - request_model=Image2VideoTaskCreationRequest, + if not initial_response.output: + raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), response_model=VideoTaskStatusResponse, - payload=payload, - node_id=cls.hidden.unique_id, + status_extractor=lambda x: x.output.task_status, estimated_duration=120 * int(duration / 5), poll_interval=6, ) diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index e69de29bb2d1..c2ec391aadd4 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -0,0 +1,87 @@ +from ._helpers import get_fs_object_size +from .client import ( + ApiEndpoint, + poll_op, + poll_op_raw, + sync_op, + sync_op_raw, +) +from .conversions import ( + audio_bytes_to_audio_input, + audio_input_to_mp3, + audio_to_base64_string, + bytesio_to_image_tensor, + downscale_image_tensor, + image_tensor_pair_to_batch, + pil_to_bytesio, + tensor_to_base64_string, + tensor_to_bytesio, + tensor_to_pil, + trim_video, +) +from .download_helpers import ( + download_url_to_bytesio, + download_url_to_image_tensor, + download_url_to_video_output, +) +from .upload_helpers import ( + upload_audio_to_comfyapi, + upload_file_to_comfyapi, + upload_images_to_comfyapi, + upload_video_to_comfyapi, +) +from .validation_utils import ( + get_number_of_images, + validate_aspect_ratio_closeness, + validate_audio_duration, + validate_container_format_is_mp4, + validate_image_aspect_ratio, + validate_image_aspect_ratio_range, + validate_image_dimensions, + validate_string, + validate_video_dimensions, + validate_video_duration, +) + +__all__ = [ + # API client + "ApiEndpoint", + "poll_op", + "poll_op_raw", + "sync_op", + "sync_op_raw", + # Upload helpers + "upload_audio_to_comfyapi", + "upload_file_to_comfyapi", + "upload_images_to_comfyapi", + "upload_video_to_comfyapi", + # Download helpers + "download_url_to_bytesio", + "download_url_to_image_tensor", + "download_url_to_video_output", + # Conversions + "audio_bytes_to_audio_input", + "audio_input_to_mp3", + "audio_to_base64_string", + "bytesio_to_image_tensor", + "downscale_image_tensor", + "image_tensor_pair_to_batch", + "pil_to_bytesio", + "tensor_to_base64_string", + "tensor_to_bytesio", + "tensor_to_pil", + "trim_video", + # Validation utilities + "get_number_of_images", + "validate_aspect_ratio_closeness", + "validate_audio_duration", + "validate_container_format_is_mp4", + "validate_image_aspect_ratio", + "validate_image_aspect_ratio_range", + "validate_image_dimensions", + "validate_string", + "validate_video_dimensions", + "validate_video_duration", + # Misc functions + "get_fs_object_size", +] diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py new file mode 100644 index 000000000000..328fe52272fd --- /dev/null +++ b/comfy_api_nodes/util/_helpers.py @@ -0,0 +1,71 @@ +import asyncio +import contextlib +import os +import time +from io import BytesIO +from typing import Callable, Optional, Union + +from comfy.cli_args import args +from comfy.model_management import processing_interrupted +from comfy_api.latest import IO + +from .common_exceptions import ProcessingInterrupted + + +def is_processing_interrupted() -> bool: + """Return True if user/runtime requested interruption.""" + return processing_interrupted() + + +def get_node_id(node_cls: type[IO.ComfyNode]) -> str: + return node_cls.hidden.unique_id + + +def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]: + if node_cls.hidden.auth_token_comfy_org: + return {"Authorization": f"Bearer {node_cls.hidden.auth_token_comfy_org}"} + if node_cls.hidden.api_key_comfy_org: + return {"X-API-KEY": node_cls.hidden.api_key_comfy_org} + return {} + + +def default_base_url() -> str: + return getattr(args, "comfy_api_base", "https://api.comfy.org") + + +async def sleep_with_interrupt( + seconds: float, + node_cls: Optional[type[IO.ComfyNode]], + label: Optional[str] = None, + start_ts: Optional[float] = None, + estimated_total: Optional[int] = None, + *, + display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None, +): + """ + Sleep in 1s slices while: + - Checking for interruption (raises ProcessingInterrupted). + - Optionally emitting time progress via display_callback (if provided). + """ + end = time.monotonic() + seconds + while True: + if is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + now = time.monotonic() + if start_ts is not None and label and display_callback: + with contextlib.suppress(Exception): + display_callback(node_cls, label, int(now - start_ts), estimated_total) + if now >= end: + break + await asyncio.sleep(min(1.0, end - now)) + + +def mimetype_to_extension(mime_type: str) -> str: + """Converts a MIME type to a file extension.""" + return mime_type.split("/")[-1].lower() + + +def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int: + if isinstance(path_or_object, str): + return os.path.getsize(path_or_object) + return len(path_or_object.getvalue()) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py new file mode 100644 index 000000000000..5833b118fdf6 --- /dev/null +++ b/comfy_api_nodes/util/client.py @@ -0,0 +1,941 @@ +import asyncio +import contextlib +import json +import logging +import socket +import time +import uuid +from dataclasses import dataclass +from enum import Enum +from io import BytesIO +from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union +from urllib.parse import urljoin, urlparse + +import aiohttp +from aiohttp.client_exceptions import ClientError, ContentTypeError +from pydantic import BaseModel + +from comfy import utils +from comfy_api.latest import IO +from comfy_api_nodes.apis import request_logger +from server import PromptServer + +from ._helpers import ( + default_base_url, + get_auth_header, + get_node_id, + is_processing_interrupted, + sleep_with_interrupt, +) +from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted + +M = TypeVar("M", bound=BaseModel) + + +class ApiEndpoint: + def __init__( + self, + path: str, + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET", + *, + query_params: Optional[dict[str, Any]] = None, + headers: Optional[dict[str, str]] = None, + ): + self.path = path + self.method = method + self.query_params = query_params or {} + self.headers = headers or {} + + +@dataclass +class _RequestConfig: + node_cls: type[IO.ComfyNode] + endpoint: ApiEndpoint + timeout: float + content_type: str + data: Optional[dict[str, Any]] + files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] + multipart_parser: Optional[Callable] + max_retries: int + retry_delay: float + retry_backoff: float + wait_label: str = "Waiting" + monitor_progress: bool = True + estimated_total: Optional[int] = None + final_label_on_success: Optional[str] = "Completed" + progress_origin_ts: Optional[float] = None + + +@dataclass +class _PollUIState: + started: float + status_label: str = "Queued" + is_queued: bool = True + price: Optional[float] = None + estimated_duration: Optional[int] = None + base_processing_elapsed: float = 0.0 # sum of completed active intervals + active_since: Optional[float] = None # start time of current active interval (None if queued) + + +_RETRY_STATUS = {408, 429, 500, 502, 503, 504} +COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"] +FAILED_STATUSES = ["cancelled", "canceled", "failed", "error"] +QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"] + + +async def sync_op( + cls: type[IO.ComfyNode], + endpoint: ApiEndpoint, + *, + response_model: Type[M], + data: Optional[BaseModel] = None, + files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, + content_type: str = "application/json", + timeout: float = 3600.0, + multipart_parser: Optional[Callable] = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + wait_label: str = "Waiting for server", + estimated_duration: Optional[int] = None, + final_label_on_success: Optional[str] = "Completed", + progress_origin_ts: Optional[float] = None, + monitor_progress: bool = True, +) -> M: + raw = await sync_op_raw( + cls, + endpoint, + data=data, + files=files, + content_type=content_type, + timeout=timeout, + multipart_parser=multipart_parser, + max_retries=max_retries, + retry_delay=retry_delay, + retry_backoff=retry_backoff, + wait_label=wait_label, + estimated_duration=estimated_duration, + as_binary=False, + final_label_on_success=final_label_on_success, + progress_origin_ts=progress_origin_ts, + monitor_progress=monitor_progress, + ) + if not isinstance(raw, dict): + raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") + return _validate_or_raise(response_model, raw) + + +async def poll_op( + cls: type[IO.ComfyNode], + poll_endpoint: ApiEndpoint, + *, + response_model: Type[M], + status_extractor: Callable[[M], Optional[Union[str, int]]], + progress_extractor: Optional[Callable[[M], Optional[int]]] = None, + price_extractor: Optional[Callable[[M], Optional[float]]] = None, + completed_statuses: Optional[list[Union[str, int]]] = None, + failed_statuses: Optional[list[Union[str, int]]] = None, + queued_statuses: Optional[list[Union[str, int]]] = None, + data: Optional[BaseModel] = None, + poll_interval: float = 5.0, + max_poll_attempts: int = 120, + timeout_per_poll: float = 120.0, + max_retries_per_poll: int = 3, + retry_delay_per_poll: float = 1.0, + retry_backoff_per_poll: float = 2.0, + estimated_duration: Optional[int] = None, + cancel_endpoint: Optional[ApiEndpoint] = None, + cancel_timeout: float = 10.0, +) -> M: + raw = await poll_op_raw( + cls, + poll_endpoint=poll_endpoint, + status_extractor=_wrap_model_extractor(response_model, status_extractor), + progress_extractor=_wrap_model_extractor(response_model, progress_extractor), + price_extractor=_wrap_model_extractor(response_model, price_extractor), + completed_statuses=completed_statuses, + failed_statuses=failed_statuses, + queued_statuses=queued_statuses, + data=data, + poll_interval=poll_interval, + max_poll_attempts=max_poll_attempts, + timeout_per_poll=timeout_per_poll, + max_retries_per_poll=max_retries_per_poll, + retry_delay_per_poll=retry_delay_per_poll, + retry_backoff_per_poll=retry_backoff_per_poll, + estimated_duration=estimated_duration, + cancel_endpoint=cancel_endpoint, + cancel_timeout=cancel_timeout, + ) + if not isinstance(raw, dict): + raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") + return _validate_or_raise(response_model, raw) + + +async def sync_op_raw( + cls: type[IO.ComfyNode], + endpoint: ApiEndpoint, + *, + data: Optional[Union[dict[str, Any], BaseModel]] = None, + files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, + content_type: str = "application/json", + timeout: float = 3600.0, + multipart_parser: Optional[Callable] = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + wait_label: str = "Waiting for server", + estimated_duration: Optional[int] = None, + as_binary: bool = False, + final_label_on_success: Optional[str] = "Completed", + progress_origin_ts: Optional[float] = None, + monitor_progress: bool = True, +) -> Union[dict[str, Any], bytes]: + """ + Make a single network request. + - If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON). + - If as_binary=True: returns bytes. + """ + if isinstance(data, BaseModel): + data = data.model_dump(exclude_none=True) + for k, v in list(data.items()): + if isinstance(v, Enum): + data[k] = v.value + cfg = _RequestConfig( + node_cls=cls, + endpoint=endpoint, + timeout=timeout, + content_type=content_type, + data=data, + files=files, + multipart_parser=multipart_parser, + max_retries=max_retries, + retry_delay=retry_delay, + retry_backoff=retry_backoff, + wait_label=wait_label, + monitor_progress=monitor_progress, + estimated_total=estimated_duration, + final_label_on_success=final_label_on_success, + progress_origin_ts=progress_origin_ts, + ) + return await _request_base(cfg, expect_binary=as_binary) + + +async def poll_op_raw( + cls: type[IO.ComfyNode], + poll_endpoint: ApiEndpoint, + *, + status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]], + progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None, + price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, + completed_statuses: Optional[list[Union[str, int]]] = None, + failed_statuses: Optional[list[Union[str, int]]] = None, + queued_statuses: Optional[list[Union[str, int]]] = None, + data: Optional[Union[dict[str, Any], BaseModel]] = None, + poll_interval: float = 5.0, + max_poll_attempts: int = 120, + timeout_per_poll: float = 120.0, + max_retries_per_poll: int = 3, + retry_delay_per_poll: float = 1.0, + retry_backoff_per_poll: float = 2.0, + estimated_duration: Optional[int] = None, + cancel_endpoint: Optional[ApiEndpoint] = None, + cancel_timeout: float = 10.0, +) -> dict[str, Any]: + """ + Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing, + checks interruption every second, and calls Cancel endpoint (if provided) on interruption. + + Uses default complete, failed and queued states assumption. + + Returns the final JSON response from the poll endpoint. + """ + completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses) + failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses) + queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses) + started = time.monotonic() + consumed_attempts = 0 # counts only non-queued polls + + progress_bar = utils.ProgressBar(100) if progress_extractor else None + last_progress: Optional[int] = None + + state = _PollUIState(started=started, estimated_duration=estimated_duration) + stop_ticker = asyncio.Event() + + async def _ticker(): + """Emit a UI update every second while polling is in progress.""" + try: + while not stop_ticker.is_set(): + if is_processing_interrupted(): + break + now = time.monotonic() + proc_elapsed = state.base_processing_elapsed + ( + (now - state.active_since) if state.active_since is not None else 0.0 + ) + _display_time_progress( + cls, + status=state.status_label, + elapsed_seconds=int(now - state.started), + estimated_total=state.estimated_duration, + price=state.price, + is_queued=state.is_queued, + processing_elapsed_seconds=int(proc_elapsed), + ) + await asyncio.sleep(1.0) + except Exception as exc: + logging.debug("Polling ticker exited: %s", exc) + + ticker_task = asyncio.create_task(_ticker()) + try: + while consumed_attempts < max_poll_attempts: + try: + resp_json = await sync_op_raw( + cls, + poll_endpoint, + data=data, + timeout=timeout_per_poll, + max_retries=max_retries_per_poll, + retry_delay=retry_delay_per_poll, + retry_backoff=retry_backoff_per_poll, + wait_label="Checking", + estimated_duration=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + if not isinstance(resp_json, dict): + raise Exception("Polling endpoint returned non-JSON response.") + except ProcessingInterrupted: + if cancel_endpoint: + with contextlib.suppress(Exception): + await sync_op_raw( + cls, + cancel_endpoint, + timeout=cancel_timeout, + max_retries=0, + wait_label="Cancelling task", + estimated_duration=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + raise + + try: + status = _normalize_status_value(status_extractor(resp_json)) + except Exception as e: + logging.error("Status extraction failed: %s", e) + status = None + + if price_extractor: + new_price = price_extractor(resp_json) + if new_price is not None: + state.price = new_price + + if progress_extractor: + new_progress = progress_extractor(resp_json) + if new_progress is not None and last_progress != new_progress: + progress_bar.update_absolute(new_progress, total=100) + last_progress = new_progress + + now_ts = time.monotonic() + is_queued = status in queued_states + + if is_queued: + if state.active_since is not None: # If we just moved from active -> queued, close the active interval + state.base_processing_elapsed += now_ts - state.active_since + state.active_since = None + else: + if state.active_since is None: # If we just moved from queued -> active, open a new active interval + state.active_since = now_ts + + state.is_queued = is_queued + state.status_label = status or ("Queued" if is_queued else "Processing") + if status in completed_states: + if state.active_since is not None: + state.base_processing_elapsed += now_ts - state.active_since + state.active_since = None + stop_ticker.set() + with contextlib.suppress(Exception): + await ticker_task + + if progress_bar and last_progress != 100: + progress_bar.update_absolute(100, total=100) + + _display_time_progress( + cls, + status=status if status else "Completed", + elapsed_seconds=int(now_ts - started), + estimated_total=estimated_duration, + price=state.price, + is_queued=False, + processing_elapsed_seconds=int(state.base_processing_elapsed), + ) + return resp_json + + if status in failed_states: + msg = f"Task failed: {json.dumps(resp_json)}" + logging.error(msg) + raise Exception(msg) + + try: + await sleep_with_interrupt(poll_interval, cls, None, None, None) + except ProcessingInterrupted: + if cancel_endpoint: + with contextlib.suppress(Exception): + await sync_op_raw( + cls, + cancel_endpoint, + timeout=cancel_timeout, + max_retries=0, + wait_label="Cancelling task", + estimated_duration=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + raise + if not is_queued: + consumed_attempts += 1 + + raise Exception( + f"Polling timed out after {max_poll_attempts} non-queued attempts " + f"(~{int(max_poll_attempts * poll_interval)}s of active polling)." + ) + except ProcessingInterrupted: + raise + except (LocalNetworkError, ApiServerError): + raise + except Exception as e: + raise Exception(f"Polling aborted due to error: {e}") from e + finally: + stop_ticker.set() + with contextlib.suppress(Exception): + await ticker_task + + +def _display_text( + node_cls: type[IO.ComfyNode], + text: Optional[str], + *, + status: Optional[Union[str, int]] = None, + price: Optional[float] = None, +) -> None: + display_lines: list[str] = [] + if status: + display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") + if price is not None: + display_lines.append(f"Price: ${float(price):,.4f}") + if text is not None: + display_lines.append(text) + if display_lines: + PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls)) + + +def _display_time_progress( + node_cls: type[IO.ComfyNode], + status: Optional[Union[str, int]], + elapsed_seconds: int, + estimated_total: Optional[int] = None, + *, + price: Optional[float] = None, + is_queued: Optional[bool] = None, + processing_elapsed_seconds: Optional[int] = None, +) -> None: + if estimated_total is not None and estimated_total > 0 and is_queued is False: + pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds + remaining = max(0, int(estimated_total) - int(pe)) + time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)" + else: + time_line = f"Time elapsed: {int(elapsed_seconds)}s" + _display_text(node_cls, time_line, status=status, price=price) + + +async def _diagnose_connectivity() -> dict[str, bool]: + """Best-effort connectivity diagnostics to distinguish local vs. server issues.""" + results = { + "internet_accessible": False, + "api_accessible": False, + "is_local_issue": False, + "is_api_issue": False, + } + timeout = aiohttp.ClientTimeout(total=5.0) + async with aiohttp.ClientSession(timeout=timeout) as session: + try: + async with session.get("https://www.google.com") as resp: + results["internet_accessible"] = resp.status < 500 + except (ClientError, asyncio.TimeoutError, socket.gaierror): + results["is_local_issue"] = True + return results + + parsed = urlparse(default_base_url()) + health_url = f"{parsed.scheme}://{parsed.netloc}/health" + with contextlib.suppress(ClientError, asyncio.TimeoutError): + async with session.get(health_url) as resp: + results["api_accessible"] = resp.status < 500 + results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"] + return results + + +def _unpack_tuple(t: tuple) -> tuple[str, Any, str]: + """Normalize (filename, value, content_type).""" + if len(t) == 2: + return t[0], t[1], "application/octet-stream" + if len(t) == 3: + return t[0], t[1], t[2] + raise ValueError("files tuple must be (filename, file[, content_type])") + + +def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]: + params = dict(endpoint_params or {}) + if method.upper() == "GET" and data: + for k, v in data.items(): + if v is not None: + params[k] = v + return params + + +def _friendly_http_message(status: int, body: Any) -> str: + if status == 401: + return "Unauthorized: Please login first to use this node." + if status == 402: + return "Payment Required: Please add credits to your account to use this node." + if status == 409: + return "There is a problem with your account. Please contact support@comfy.org." + if status == 429: + return "Rate Limit Exceeded: Please try again later." + try: + if isinstance(body, dict): + err = body.get("error") + if isinstance(err, dict): + msg = err.get("message") + typ = err.get("type") + if msg and typ: + return f"API Error: {msg} (Type: {typ})" + if msg: + return f"API Error: {msg}" + return f"API Error: {json.dumps(body)}" + else: + txt = str(body) + if len(txt) <= 200: + return f"API Error (raw): {txt}" + return f"API Error (status {status})" + except Exception: + return f"HTTP {status}: Unknown error" + + +def _generate_operation_id(method: str, path: str, attempt: int) -> str: + slug = path.strip("/").replace("/", "_") or "op" + return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}" + + +def _snapshot_request_body_for_logging( + content_type: str, + method: str, + data: Optional[dict[str, Any]], + files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]], +) -> Optional[Union[dict[str, Any], str]]: + if method.upper() == "GET": + return None + if content_type == "multipart/form-data": + form_fields = sorted([k for k, v in (data or {}).items() if v is not None]) + file_fields: list[dict[str, str]] = [] + if files: + file_iter = files if isinstance(files, list) else list(files.items()) + for field_name, file_obj in file_iter: + if file_obj is None: + continue + if isinstance(file_obj, tuple): + filename = file_obj[0] + else: + filename = getattr(file_obj, "name", field_name) + file_fields.append({"field": field_name, "filename": str(filename or "")}) + return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields} + if content_type == "application/x-www-form-urlencoded": + return data or {} + return data or {} + + +async def _request_base(cfg: _RequestConfig, expect_binary: bool): + """Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors.""" + url = cfg.endpoint.path + parsed_url = urlparse(url) + if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? + url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) + + method = cfg.endpoint.method + params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None) + + async def _monitor(stop_evt: asyncio.Event, start_ts: float): + """Every second: update elapsed time and signal interruption.""" + try: + while not stop_evt.is_set(): + if is_processing_interrupted(): + return + if cfg.monitor_progress: + _display_time_progress( + cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total + ) + await asyncio.sleep(1.0) + except asyncio.CancelledError: + return # normal shutdown + + start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic() + attempt = 0 + delay = cfg.retry_delay + operation_succeeded: bool = False + final_elapsed_seconds: Optional[int] = None + while True: + attempt += 1 + stop_event = asyncio.Event() + monitor_task: Optional[asyncio.Task] = None + sess: Optional[aiohttp.ClientSession] = None + + operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) + logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) + + payload_headers = {"Accept": "*/*"} + if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? + payload_headers.update(get_auth_header(cfg.node_cls)) + if cfg.endpoint.headers: + payload_headers.update(cfg.endpoint.headers) + + payload_kw: dict[str, Any] = {"headers": payload_headers} + if method == "GET": + payload_headers.pop("Content-Type", None) + request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files) + try: + if cfg.monitor_progress: + monitor_task = asyncio.create_task(_monitor(stop_event, start_time)) + + timeout = aiohttp.ClientTimeout(total=cfg.timeout) + sess = aiohttp.ClientSession(timeout=timeout) + + if cfg.content_type == "multipart/form-data" and method != "GET": + # aiohttp will set Content-Type boundary; remove any fixed Content-Type + payload_headers.pop("Content-Type", None) + if cfg.multipart_parser and cfg.data: + form = cfg.multipart_parser(cfg.data) + if not isinstance(form, aiohttp.FormData): + raise ValueError("multipart_parser must return aiohttp.FormData") + else: + form = aiohttp.FormData(default_to_multipart=True) + if cfg.data: + for k, v in cfg.data.items(): + if v is None: + continue + form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) + if cfg.files: + file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items() + for field_name, file_obj in file_iter: + if file_obj is None: + continue + if isinstance(file_obj, tuple): + filename, file_value, content_type = _unpack_tuple(file_obj) + else: + filename = getattr(file_obj, "name", field_name) + file_value = file_obj + content_type = "application/octet-stream" + # Attempt to rewind BytesIO for retries + if isinstance(file_value, BytesIO): + with contextlib.suppress(Exception): + file_value.seek(0) + form.add_field(field_name, file_value, filename=filename, content_type=content_type) + payload_kw["data"] = form + elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET": + payload_headers["Content-Type"] = "application/x-www-form-urlencoded" + payload_kw["data"] = cfg.data or {} + elif method != "GET": + payload_headers["Content-Type"] = "application/json" + payload_kw["json"] = cfg.data or {} + + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + ) + except Exception as _log_e: + logging.debug("[DEBUG] request logging failed: %s", _log_e) + + req_coro = sess.request(method, url, params=params, **payload_kw) + req_task = asyncio.create_task(req_coro) + + # Race: request vs. monitor (interruption) + tasks = {req_task} + if monitor_task: + tasks.add(monitor_task) + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + if monitor_task and monitor_task in done: + # Interrupted – cancel the request and abort + if req_task in pending: + req_task.cancel() + raise ProcessingInterrupted("Task cancelled") + + # Otherwise, request finished + resp = await req_task + async with resp: + if resp.status >= 400: + try: + body = await resp.json() + except (ContentTypeError, json.JSONDecodeError): + body = await resp.text() + if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries: + logging.warning( + "HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).", + method, + url, + resp.status, + delay, + attempt, + cfg.max_retries, + ) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=_friendly_http_message(resp.status, body), + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + + await sleep_with_interrupt( + delay, + cfg.node_cls, + cfg.wait_label if cfg.monitor_progress else None, + start_time if cfg.monitor_progress else None, + cfg.estimated_total, + display_callback=_display_time_progress if cfg.monitor_progress else None, + ) + delay *= cfg.retry_backoff + continue + msg = _friendly_http_message(resp.status, body) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=msg, + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + raise Exception(msg) + + if expect_binary: + buff = bytearray() + last_tick = time.monotonic() + async for chunk in resp.content.iter_chunked(64 * 1024): + buff.extend(chunk) + now = time.monotonic() + if now - last_tick >= 1.0: + last_tick = now + if is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + if cfg.monitor_progress: + _display_time_progress( + cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total + ) + bytes_payload = bytes(buff) + operation_succeeded = True + final_elapsed_seconds = int(time.monotonic() - start_time) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=bytes_payload, + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + return bytes_payload + else: + try: + payload = await resp.json() + response_content_to_log: Any = payload + except (ContentTypeError, json.JSONDecodeError): + text = await resp.text() + try: + payload = json.loads(text) if text else {} + except json.JSONDecodeError: + payload = {"_raw": text} + response_content_to_log = payload if isinstance(payload, dict) else text + operation_succeeded = True + final_elapsed_seconds = int(time.monotonic() - start_time) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=response_content_to_log, + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + return payload + + except ProcessingInterrupted: + logging.debug("Polling was interrupted by user") + raise + except (ClientError, asyncio.TimeoutError, socket.gaierror) as e: + if attempt <= cfg.max_retries: + logging.warning( + "Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s", + method, + url, + delay, + attempt, + cfg.max_retries, + str(e), + ) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) + except Exception as _log_e: + logging.debug("[DEBUG] request error logging failed: %s", _log_e) + await sleep_with_interrupt( + delay, + cfg.node_cls, + cfg.wait_label if cfg.monitor_progress else None, + start_time if cfg.monitor_progress else None, + cfg.estimated_total, + display_callback=_display_time_progress if cfg.monitor_progress else None, + ) + delay *= cfg.retry_backoff + continue + diag = await _diagnose_connectivity() + if diag.get("is_local_issue"): + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"LocalNetworkError: {str(e)}", + ) + except Exception as _log_e: + logging.debug("[DEBUG] final error logging failed: %s", _log_e) + raise LocalNetworkError( + "Unable to connect to the API server due to local network issues. " + "Please check your internet connection and try again." + ) from e + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"ApiServerError: {str(e)}", + ) + except Exception as _log_e: + logging.debug("[DEBUG] final error logging failed: %s", _log_e) + raise ApiServerError( + f"The API server at {default_base_url()} is currently unreachable. " + f"The service may be experiencing issues." + ) from e + finally: + stop_event.set() + if monitor_task: + monitor_task.cancel() + with contextlib.suppress(Exception): + await monitor_task + if sess: + with contextlib.suppress(Exception): + await sess.close() + if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success: + _display_time_progress( + cfg.node_cls, + status=cfg.final_label_on_success, + elapsed_seconds=( + final_elapsed_seconds + if final_elapsed_seconds is not None + else int(time.monotonic() - start_time) + ), + estimated_total=cfg.estimated_total, + price=None, + is_queued=False, + processing_elapsed_seconds=final_elapsed_seconds, + ) + + +def _validate_or_raise(response_model: Type[M], payload: Any) -> M: + try: + return response_model.model_validate(payload) + except Exception as e: + logging.error( + "Response validation failed for %s: %s", + getattr(response_model, "__name__", response_model), + e, + ) + raise Exception( + f"Response validation failed for {getattr(response_model, '__name__', response_model)}: {e}" + ) from e + + +def _wrap_model_extractor( + response_model: Type[M], + extractor: Optional[Callable[[M], Any]], +) -> Optional[Callable[[dict[str, Any]], Any]]: + """Wrap a typed extractor so it can be used by the dict-based poller. + Validates the dict into `response_model` before invoking `extractor`. + Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating + the same response for multiple extractors in a single poll attempt. + """ + if extractor is None: + return None + _cache: dict[int, M] = {} + + def _wrapped(d: dict[str, Any]) -> Any: + try: + key = id(d) + model = _cache.get(key) + if model is None: + model = response_model.model_validate(d) + _cache[key] = model + return extractor(model) + except Exception as e: + logging.error("Extractor failed (typed -> dict wrapper): %s", e) + raise + + return _wrapped + + +def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]: + if not values: + return set() + out: set[Union[str, int]] = set() + for v in values: + nv = _normalize_status_value(v) + if nv is not None: + out.add(nv) + return out + + +def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]: + if isinstance(val, str): + return val.strip().lower() + return val diff --git a/comfy_api_nodes/util/common_exceptions.py b/comfy_api_nodes/util/common_exceptions.py new file mode 100644 index 000000000000..0606a4407007 --- /dev/null +++ b/comfy_api_nodes/util/common_exceptions.py @@ -0,0 +1,14 @@ +class NetworkError(Exception): + """Base exception for network-related errors with diagnostic information.""" + + +class LocalNetworkError(NetworkError): + """Exception raised when local network connectivity issues are detected.""" + + +class ApiServerError(NetworkError): + """Exception raised when the API server is unreachable but internet is working.""" + + +class ProcessingInterrupted(Exception): + """Operation was interrupted by user/runtime via processing_interrupted().""" diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py new file mode 100644 index 000000000000..10cd1051b4d7 --- /dev/null +++ b/comfy_api_nodes/util/conversions.py @@ -0,0 +1,407 @@ +import base64 +import logging +import math +import uuid +from io import BytesIO +from typing import Optional + +import av +import numpy as np +import torch +from PIL import Image + +from comfy.utils import common_upscale +from comfy_api.latest import Input, InputImpl + +from ._helpers import mimetype_to_extension + + +def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor: + """Converts image data from BytesIO to a torch.Tensor. + + Args: + image_bytesio: BytesIO object containing the image data. + mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA"). + + Returns: + A torch.Tensor representing the image (1, H, W, C). + + Raises: + PIL.UnidentifiedImageError: If the image data cannot be identified. + ValueError: If the specified mode is invalid. + """ + image = Image.open(image_bytesio) + image = image.convert(mode) + image_array = np.array(image).astype(np.float32) / 255.0 + return torch.from_numpy(image_array).unsqueeze(0) + + +def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Converts a pair of image tensors to a batch tensor. + If the images are not the same size, the smaller image is resized to + match the larger image. + """ + if image1.shape[1:] != image2.shape[1:]: + image2 = common_upscale( + image2.movedim(-1, 1), + image1.shape[2], + image1.shape[1], + "bilinear", + "center", + ).movedim(1, -1) + return torch.cat((image1, image2), dim=0) + + +def tensor_to_bytesio( + image: torch.Tensor, + name: Optional[str] = None, + total_pixels: int = 2048 * 2048, + mime_type: str = "image/png", +) -> BytesIO: + """Converts a torch.Tensor image to a named BytesIO object. + + Args: + image: Input torch.Tensor image. + name: Optional filename for the BytesIO object. + total_pixels: Maximum total pixels for potential downscaling. + mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). + + Returns: + Named BytesIO object containing the image data, with pointer set to the start of buffer. + """ + if not mime_type: + mime_type = "image/png" + + pil_image = tensor_to_pil(image, total_pixels=total_pixels) + img_binary = pil_to_bytesio(pil_image, mime_type=mime_type) + img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}" + return img_binary + + +def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image: + """Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling.""" + if len(image.shape) > 3: + image = image[0] + # TODO: remove alpha if not allowed and present + input_tensor = image.cpu() + input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze() + image_np = (input_tensor.numpy() * 255).astype(np.uint8) + img = Image.fromarray(image_np) + return img + + +def tensor_to_base64_string( + image_tensor: torch.Tensor, + total_pixels: int = 2048 * 2048, + mime_type: str = "image/png", +) -> str: + """Convert [B, H, W, C] or [H, W, C] tensor to a base64 string. + + Args: + image_tensor: Input torch.Tensor image. + total_pixels: Maximum total pixels for potential downscaling. + mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). + + Returns: + Base64 encoded string of the image. + """ + pil_image = tensor_to_pil(image_tensor, total_pixels=total_pixels) + img_byte_arr = pil_to_bytesio(pil_image, mime_type=mime_type) + img_bytes = img_byte_arr.getvalue() + # Encode bytes to base64 string + base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8") + return base64_encoded_string + + +def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO: + """Converts a PIL Image to a BytesIO object.""" + if not mime_type: + mime_type = "image/png" + + img_byte_arr = BytesIO() + # Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG') + pil_format = mime_type.split("/")[-1].upper() + if pil_format == "JPG": + pil_format = "JPEG" + img.save(img_byte_arr, format=pil_format) + img_byte_arr.seek(0) + return img_byte_arr + + +def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: + """Downscale input image tensor to roughly the specified total pixels.""" + samples = image.movedim(-1, 1) + total = int(total_pixels) + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + if scale_by >= 1: + return image + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = common_upscale(samples, width, height, "lanczos", "disabled") + s = s.movedim(1, -1) + return s + + +def tensor_to_data_uri( + image_tensor: torch.Tensor, + total_pixels: int = 2048 * 2048, + mime_type: str = "image/png", +) -> str: + """Converts a tensor image to a Data URI string. + + Args: + image_tensor: Input torch.Tensor image. + total_pixels: Maximum total pixels for potential downscaling. + mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp'). + + Returns: + Data URI string (e.g., 'data:image/png;base64,...'). + """ + base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type) + return f"data:{mime_type};base64,{base64_string}" + + +def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", codec_name: str = "aac") -> str: + """Converts an audio input to a base64 string.""" + sample_rate: int = audio["sample_rate"] + waveform: torch.Tensor = audio["waveform"] + audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) + audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) + audio_bytes = audio_bytes_io.getvalue() + return base64.b64encode(audio_bytes).decode("utf-8") + + +def audio_ndarray_to_bytesio( + audio_data_np: np.ndarray, + sample_rate: int, + container_format: str = "mp4", + codec_name: str = "aac", +) -> BytesIO: + """ + Encodes a numpy array of audio data into a BytesIO object. + """ + audio_bytes_io = BytesIO() + with av.open(audio_bytes_io, mode="w", format=container_format) as output_container: + audio_stream = output_container.add_stream(codec_name, rate=sample_rate) + frame = av.AudioFrame.from_ndarray( + audio_data_np, + format="fltp", + layout="stereo" if audio_data_np.shape[0] > 1 else "mono", + ) + frame.sample_rate = sample_rate + frame.pts = 0 + + for packet in audio_stream.encode(frame): + output_container.mux(packet) + + # Flush stream + for packet in audio_stream.encode(None): + output_container.mux(packet) + + audio_bytes_io.seek(0) + return audio_bytes_io + + +def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray: + """ + Prepares audio waveform for av library by converting to a contiguous numpy array. + + Args: + waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type. + + Returns: + Contiguous numpy array of the audio waveform. If the audio was batched, + the first item is taken. + """ + if waveform.ndim != 3 or waveform.shape[0] != 1: + raise ValueError("Expected waveform tensor shape (1, channels, samples)") + + # If batch is > 1, take first item + if waveform.shape[0] > 1: + waveform = waveform[0] + + # Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array + audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy() + if audio_data_np.dtype != np.float32: + audio_data_np = audio_data_np.astype(np.float32) + + return audio_data_np + + +def audio_input_to_mp3(audio: Input.Audio) -> BytesIO: + waveform = audio["waveform"].cpu() + + output_buffer = BytesIO() + output_container = av.open(output_buffer, mode="w", format="mp3") + + out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"]) + out_stream.bit_rate = 320000 + + frame = av.AudioFrame.from_ndarray( + waveform.movedim(0, 1).reshape(1, -1).float().numpy(), + format="flt", + layout="mono" if waveform.shape[0] == 1 else "stereo", + ) + frame.sample_rate = audio["sample_rate"] + frame.pts = 0 + output_container.mux(out_stream.encode(frame)) + output_container.mux(out_stream.encode(None)) + output_container.close() + output_buffer.seek(0) + return output_buffer + + +def trim_video(video: Input.Video, duration_sec: float) -> Input.Video: + """ + Returns a new VideoInput object trimmed from the beginning to the specified duration, + using av to avoid loading entire video into memory. + + Args: + video: Input video to trim + duration_sec: Duration in seconds to keep from the beginning + + Returns: + VideoFromFile object that owns the output buffer + """ + output_buffer = BytesIO() + input_container = None + output_container = None + + try: + # Get the stream source - this avoids loading entire video into memory + # when the source is already a file path + input_source = video.get_stream_source() + + # Open containers + input_container = av.open(input_source, mode="r") + output_container = av.open(output_buffer, mode="w", format="mp4") + + # Set up output streams for re-encoding + video_stream = None + audio_stream = None + + for stream in input_container.streams: + logging.info("Found stream: type=%s, class=%s", stream.type, type(stream)) + if isinstance(stream, av.VideoStream): + # Create output video stream with same parameters + video_stream = output_container.add_stream("h264", rate=stream.average_rate) + video_stream.width = stream.width + video_stream.height = stream.height + video_stream.pix_fmt = "yuv420p" + logging.info("Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate) + elif isinstance(stream, av.AudioStream): + # Create output audio stream with same parameters + audio_stream = output_container.add_stream("aac", rate=stream.sample_rate) + audio_stream.sample_rate = stream.sample_rate + audio_stream.layout = stream.layout + logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels) + + # Calculate target frame count that's divisible by 16 + fps = input_container.streams.video[0].average_rate + estimated_frames = int(duration_sec * fps) + target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16 + + if target_frames == 0: + raise ValueError("Video too short: need at least 16 frames for Moonvalley") + + frame_count = 0 + audio_frame_count = 0 + + # Decode and re-encode video frames + if video_stream: + for frame in input_container.decode(video=0): + if frame_count >= target_frames: + break + + # Re-encode frame + for packet in video_stream.encode(frame): + output_container.mux(packet) + frame_count += 1 + + # Flush encoder + for packet in video_stream.encode(): + output_container.mux(packet) + + logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames) + + # Decode and re-encode audio frames + if audio_stream: + input_container.seek(0) # Reset to beginning for audio + for frame in input_container.decode(audio=0): + if frame.time >= duration_sec: + break + + # Re-encode frame + for packet in audio_stream.encode(frame): + output_container.mux(packet) + audio_frame_count += 1 + + # Flush encoder + for packet in audio_stream.encode(): + output_container.mux(packet) + + logging.info("Encoded %s audio frames", audio_frame_count) + + # Close containers + output_container.close() + input_container.close() + + # Return as VideoFromFile using the buffer + output_buffer.seek(0) + return InputImpl.VideoFromFile(output_buffer) + + except Exception as e: + # Clean up on error + if input_container is not None: + input_container.close() + if output_container is not None: + output_container.close() + raise RuntimeError(f"Failed to trim video: {str(e)}") from e + + +def _f32_pcm(wav: torch.Tensor) -> torch.Tensor: + """Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file.""" + if wav.dtype.is_floating_point: + return wav + elif wav.dtype == torch.int16: + return wav.float() / (2**15) + elif wav.dtype == torch.int32: + return wav.float() / (2**31) + raise ValueError(f"Unsupported wav dtype: {wav.dtype}") + + +def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict: + """ + Decode any common audio container from bytes using PyAV and return + a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}. + """ + with av.open(BytesIO(audio_bytes)) as af: + if not af.streams.audio: + raise ValueError("No audio stream found in response.") + stream = af.streams.audio[0] + + in_sr = int(stream.codec_context.sample_rate) + out_sr = in_sr + + frames: list[torch.Tensor] = [] + n_channels = stream.channels or 1 + + for frame in af.decode(streams=stream.index): + arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T] + buf = torch.from_numpy(arr) + if buf.ndim == 1: + buf = buf.unsqueeze(0) # [T] -> [1, T] + elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels: + buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T] + elif buf.shape[0] != n_channels: + buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T] + frames.append(buf) + + if not frames: + raise ValueError("Decoded zero audio frames.") + + wav = torch.cat(frames, dim=1) # [C, T] + wav = _f32_pcm(wav) + return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr} diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py new file mode 100644 index 000000000000..055e690de4e5 --- /dev/null +++ b/comfy_api_nodes/util/download_helpers.py @@ -0,0 +1,249 @@ +import asyncio +import contextlib +import uuid +from io import BytesIO +from pathlib import Path +from typing import IO, Optional, Union +from urllib.parse import urljoin, urlparse + +import aiohttp +import torch +from aiohttp.client_exceptions import ClientError, ContentTypeError + +from comfy_api.input_impl import VideoFromFile +from comfy_api.latest import IO as COMFY_IO +from comfy_api_nodes.apis import request_logger + +from ._helpers import ( + default_base_url, + get_auth_header, + is_processing_interrupted, + sleep_with_interrupt, +) +from .client import _diagnose_connectivity +from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted +from .conversions import bytesio_to_image_tensor + +_RETRY_STATUS = {408, 429, 500, 502, 503, 504} + + +async def download_url_to_bytesio( + url: str, + dest: Optional[Union[BytesIO, IO[bytes], str, Path]], + *, + timeout: Optional[float] = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + cls: type[COMFY_IO.ComfyNode] = None, +) -> None: + """Stream-download a URL to `dest`. + + `dest` must be one of: + - a BytesIO (rewound to 0 after write), + - a file-like object opened in binary write mode (must implement .write()), + - a filesystem path (str | pathlib.Path), which will be opened with 'wb'. + + If `url` starts with `/proxy/`, `cls` must be provided so the URL can be expanded + to an absolute URL and authentication headers can be applied. + + Raises: + ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors) + """ + if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"): + raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().") + + attempt = 0 + delay = retry_delay + headers: dict[str, str] = {} + + parsed_url = urlparse(url) + if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? + if cls is None: + raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.") + url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) + headers = get_auth_header(cls) + + while True: + attempt += 1 + op_id = _generate_operation_id("GET", url, attempt) + timeout_cfg = aiohttp.ClientTimeout(total=timeout) + + is_path_sink = isinstance(dest, (str, Path)) + fhandle = None + session: Optional[aiohttp.ClientSession] = None + stop_evt: Optional[asyncio.Event] = None + monitor_task: Optional[asyncio.Task] = None + req_task: Optional[asyncio.Task] = None + + try: + with contextlib.suppress(Exception): + request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url) + + session = aiohttp.ClientSession(timeout=timeout_cfg) + stop_evt = asyncio.Event() + + async def _monitor(): + try: + while not stop_evt.is_set(): + if is_processing_interrupted(): + return + await asyncio.sleep(1.0) + except asyncio.CancelledError: + return + + monitor_task = asyncio.create_task(_monitor()) + + req_task = asyncio.create_task(session.get(url, headers=headers)) + done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) + + if monitor_task in done and req_task in pending: + req_task.cancel() + with contextlib.suppress(Exception): + await req_task + raise ProcessingInterrupted("Task cancelled") + + try: + resp = await req_task + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None + + async with resp: + if resp.status >= 400: + with contextlib.suppress(Exception): + try: + body = await resp.json() + except (ContentTypeError, ValueError): + text = await resp.text() + body = text if len(text) <= 4096 else f"[text {len(text)} bytes]" + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=f"HTTP {resp.status}", + ) + + if resp.status in _RETRY_STATUS and attempt <= max_retries: + await sleep_with_interrupt(delay, cls, None, None, None) + delay *= retry_backoff + continue + raise Exception(f"Failed to download (HTTP {resp.status}).") + + if is_path_sink: + p = Path(str(dest)) + with contextlib.suppress(Exception): + p.parent.mkdir(parents=True, exist_ok=True) + fhandle = open(p, "wb") + sink = fhandle + else: + sink = dest # BytesIO or file-like + + written = 0 + while True: + try: + chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0) + except asyncio.TimeoutError: + chunk = b"" + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None + + if is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + + if not chunk: + if resp.content.at_eof(): + break + continue + + sink.write(chunk) + written += len(chunk) + + if isinstance(dest, BytesIO): + with contextlib.suppress(Exception): + dest.seek(0) + + with contextlib.suppress(Exception): + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=f"[streamed {written} bytes to dest]", + ) + return + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None + except (ClientError, asyncio.TimeoutError) as e: + if attempt <= max_retries: + with contextlib.suppress(Exception): + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) + await sleep_with_interrupt(delay, cls, None, None, None) + delay *= retry_backoff + continue + + diag = await _diagnose_connectivity() + if diag.get("is_local_issue"): + raise LocalNetworkError( + "Unable to connect to the network. Please check your internet connection and try again." + ) from e + raise ApiServerError("The remote service appears unreachable at this time.") from e + finally: + if stop_evt is not None: + stop_evt.set() + if monitor_task: + monitor_task.cancel() + with contextlib.suppress(Exception): + await monitor_task + if req_task and not req_task.done(): + req_task.cancel() + with contextlib.suppress(Exception): + await req_task + if session: + with contextlib.suppress(Exception): + await session.close() + if fhandle: + with contextlib.suppress(Exception): + fhandle.flush() + fhandle.close() + + +async def download_url_to_image_tensor( + url: str, + *, + timeout: float = None, + cls: type[COMFY_IO.ComfyNode] = None, +) -> torch.Tensor: + """Downloads an image from a URL and returns a [B, H, W, C] tensor.""" + result = BytesIO() + await download_url_to_bytesio(url, result, timeout=timeout, cls=cls) + return bytesio_to_image_tensor(result) + + +async def download_url_to_video_output( + video_url: str, + *, + timeout: float = None, + cls: type[COMFY_IO.ComfyNode] = None, +) -> VideoFromFile: + """Downloads a video from a URL and returns a `VIDEO` output.""" + result = BytesIO() + await download_url_to_bytesio(video_url, result, timeout=timeout, cls=cls) + return VideoFromFile(result) + + +def _generate_operation_id(method: str, url: str, attempt: int) -> str: + try: + parsed = urlparse(url) + slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download").strip("/").replace("/", "_") + except Exception: + slug = "download" + return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}" diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py new file mode 100644 index 000000000000..a345d451d4cd --- /dev/null +++ b/comfy_api_nodes/util/upload_helpers.py @@ -0,0 +1,338 @@ +import asyncio +import contextlib +import logging +import time +import uuid +from io import BytesIO +from typing import Optional, Union +from urllib.parse import urlparse + +import aiohttp +import torch +from pydantic import BaseModel, Field + +from comfy_api.latest import IO, Input +from comfy_api.util import VideoCodec, VideoContainer +from comfy_api_nodes.apis import request_logger + +from ._helpers import is_processing_interrupted, sleep_with_interrupt +from .client import ( + ApiEndpoint, + _diagnose_connectivity, + _display_time_progress, + sync_op, +) +from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted +from .conversions import ( + audio_ndarray_to_bytesio, + audio_tensor_to_contiguous_ndarray, + tensor_to_bytesio, +) + + +class UploadRequest(BaseModel): + file_name: str = Field(..., description="Filename to upload") + content_type: Optional[str] = Field( + None, + description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.", + ) + + +class UploadResponse(BaseModel): + download_url: str = Field(..., description="URL to GET uploaded file") + upload_url: str = Field(..., description="URL to PUT file to upload") + + +async def upload_images_to_comfyapi( + cls: type[IO.ComfyNode], + image: torch.Tensor, + *, + max_images: int = 8, + mime_type: Optional[str] = None, + wait_label: Optional[str] = "Uploading", +) -> list[str]: + """ + Uploads images to ComfyUI API and returns download URLs. + To upload multiple images, stack them in the batch dimension first. + """ + # if batch, try to upload each file if max_images is greater than 0 + download_urls: list[str] = [] + is_batch = len(image.shape) > 3 + batch_len = image.shape[0] if is_batch else 1 + + for idx in range(min(batch_len, max_images)): + tensor = image[idx] if is_batch else image + img_io = tensor_to_bytesio(tensor, mime_type=mime_type) + url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label) + download_urls.append(url) + return download_urls + + +async def upload_audio_to_comfyapi( + cls: type[IO.ComfyNode], + audio: Input.Audio, + *, + container_format: str = "mp4", + codec_name: str = "aac", + mime_type: str = "audio/mp4", + filename: str = "uploaded_audio.mp4", +) -> str: + """ + Uploads a single audio input to ComfyUI API and returns its download URL. + Encodes the raw waveform into the specified format before uploading. + """ + sample_rate: int = audio["sample_rate"] + waveform: torch.Tensor = audio["waveform"] + audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) + audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) + return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type) + + +async def upload_video_to_comfyapi( + cls: type[IO.ComfyNode], + video: Input.Video, + *, + container: VideoContainer = VideoContainer.MP4, + codec: VideoCodec = VideoCodec.H264, + max_duration: Optional[int] = None, +) -> str: + """ + Uploads a single video to ComfyUI API and returns its download URL. + Uses the specified container and codec for saving the video before upload. + """ + if max_duration is not None: + try: + actual_duration = video.get_duration() + if actual_duration > max_duration: + raise ValueError( + f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)." + ) + except Exception as e: + logging.error("Error getting video duration: %s", str(e)) + raise ValueError(f"Could not verify video duration from source: {e}") from e + + upload_mime_type = f"video/{container.value.lower()}" + filename = f"uploaded_video.{container.value.lower()}" + + # Convert VideoInput to BytesIO using specified container/codec + video_bytes_io = BytesIO() + video.save_to(video_bytes_io, format=container, codec=codec) + video_bytes_io.seek(0) + + return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type) + + +async def upload_file_to_comfyapi( + cls: type[IO.ComfyNode], + file_bytes_io: BytesIO, + filename: str, + upload_mime_type: Optional[str], + wait_label: Optional[str] = "Uploading", +) -> str: + """Uploads a single file to ComfyUI API and returns its download URL.""" + if upload_mime_type is None: + request_object = UploadRequest(file_name=filename) + else: + request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) + create_resp = await sync_op( + cls, + endpoint=ApiEndpoint(path="/customers/storage", method="POST"), + data=request_object, + response_model=UploadResponse, + final_label_on_success=None, + monitor_progress=False, + ) + await upload_file( + cls, + create_resp.upload_url, + file_bytes_io, + content_type=upload_mime_type, + wait_label=wait_label, + ) + return create_resp.download_url + + +async def upload_file( + cls: type[IO.ComfyNode], + upload_url: str, + file: Union[BytesIO, str], + *, + content_type: Optional[str] = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + wait_label: Optional[str] = None, +) -> None: + """ + Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption. + + Args: + cls: Node class (provides auth context + UI progress hooks). + upload_url: Pre-signed PUT URL. + file: BytesIO or path string. + content_type: Explicit MIME type. If None, we *suppress* Content-Type. + max_retries: Maximum retry attempts. + retry_delay: Initial delay in seconds. + retry_backoff: Exponential backoff factor. + wait_label: Progress label shown in Comfy UI. + + Raises: + ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception + """ + if isinstance(file, BytesIO): + with contextlib.suppress(Exception): + file.seek(0) + data = file.read() + elif isinstance(file, str): + with open(file, "rb") as f: + data = f.read() + else: + raise ValueError("file must be a BytesIO or a filesystem path string") + + headers: dict[str, str] = {} + skip_auto_headers: set[str] = set() + if content_type: + headers["Content-Type"] = content_type + else: + skip_auto_headers.add("Content-Type") # Don't let aiohttp add Content-Type, it can break the signed request + + attempt = 0 + delay = retry_delay + start_ts = time.monotonic() + op_uuid = uuid.uuid4().hex[:8] + while True: + attempt += 1 + operation_id = _generate_operation_id("PUT", upload_url, attempt, op_uuid) + timeout = aiohttp.ClientTimeout(total=None) + stop_evt = asyncio.Event() + + async def _monitor(): + try: + while not stop_evt.is_set(): + if is_processing_interrupted(): + return + if wait_label: + _display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None) + await asyncio.sleep(1.0) + except asyncio.CancelledError: + return + + monitor_task = asyncio.create_task(_monitor()) + sess: Optional[aiohttp.ClientSession] = None + try: + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + request_headers=headers or None, + request_params=None, + request_data=f"[File data {len(data)} bytes]", + ) + except Exception as e: + logging.debug("[DEBUG] upload request logging failed: %s", e) + + sess = aiohttp.ClientSession(timeout=timeout) + req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers) + req_task = asyncio.create_task(req) + + done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) + + if monitor_task in done and req_task in pending: + req_task.cancel() + raise ProcessingInterrupted("Upload cancelled") + + try: + resp = await req_task + except asyncio.CancelledError: + raise ProcessingInterrupted("Upload cancelled") from None + + async with resp: + if resp.status >= 400: + with contextlib.suppress(Exception): + try: + body = await resp.json() + except Exception: + body = await resp.text() + msg = f"Upload failed with status {resp.status}" + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=msg, + ) + if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries: + await sleep_with_interrupt( + delay, + cls, + wait_label, + start_ts, + None, + display_callback=_display_time_progress if wait_label else None, + ) + delay *= retry_backoff + continue + raise Exception(f"Failed to upload (HTTP {resp.status}).") + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content="File uploaded successfully.", + ) + except Exception as e: + logging.debug("[DEBUG] upload response logging failed: %s", e) + return + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + if attempt <= max_retries: + with contextlib.suppress(Exception): + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + request_headers=headers or None, + request_data=f"[File data {len(data)} bytes]", + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) + await sleep_with_interrupt( + delay, + cls, + wait_label, + start_ts, + None, + display_callback=_display_time_progress if wait_label else None, + ) + delay *= retry_backoff + continue + + diag = await _diagnose_connectivity() + if diag.get("is_local_issue"): + raise LocalNetworkError( + "Unable to connect to the network. Please check your internet connection and try again." + ) from e + raise ApiServerError("The API service appears unreachable at this time.") from e + finally: + stop_evt.set() + if monitor_task: + monitor_task.cancel() + with contextlib.suppress(Exception): + await monitor_task + if sess: + with contextlib.suppress(Exception): + await sess.close() + + +def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str: + try: + parsed = urlparse(url) + slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_") + except Exception: + slug = "upload" + return f"{method}_{slug}_{op_uuid}_try{attempt}" diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py index ca913e9b3eed..22da05bc199d 100644 --- a/comfy_api_nodes/util/validation_utils.py +++ b/comfy_api_nodes/util/validation_utils.py @@ -2,6 +2,8 @@ from typing import Optional import torch + +from comfy_api.input.video_types import VideoInput from comfy_api.latest import Input @@ -28,9 +30,7 @@ def validate_image_dimensions( if max_width is not None and width > max_width: raise ValueError(f"Image width must be at most {max_width}px, got {width}px") if min_height is not None and height < min_height: - raise ValueError( - f"Image height must be at least {min_height}px, got {height}px" - ) + raise ValueError(f"Image height must be at least {min_height}px, got {height}px") if max_height is not None and height > max_height: raise ValueError(f"Image height must be at most {max_height}px, got {height}px") @@ -44,13 +44,9 @@ def validate_image_aspect_ratio( aspect_ratio = width / height if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio: - raise ValueError( - f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}" - ) + raise ValueError(f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}") if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio: - raise ValueError( - f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}" - ) + raise ValueError(f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}") def validate_image_aspect_ratio_range( @@ -58,7 +54,7 @@ def validate_image_aspect_ratio_range( min_ratio: tuple[float, float], # e.g. (1, 4) max_ratio: tuple[float, float], # e.g. (4, 1) *, - strict: bool = True, # True -> (min, max); False -> [min, max] + strict: bool = True, # True -> (min, max); False -> [min, max] ) -> float: a1, b1 = min_ratio a2, b2 = max_ratio @@ -85,7 +81,7 @@ def validate_aspect_ratio_closeness( min_rel: float, max_rel: float, *, - strict: bool = False, # True => exclusive, False => inclusive + strict: bool = False, # True => exclusive, False => inclusive ) -> None: w1, h1 = get_image_dimensions(start_img) w2, h2 = get_image_dimensions(end_img) @@ -118,9 +114,7 @@ def validate_video_dimensions( if max_width is not None and width > max_width: raise ValueError(f"Video width must be at most {max_width}px, got {width}px") if min_height is not None and height < min_height: - raise ValueError( - f"Video height must be at least {min_height}px, got {height}px" - ) + raise ValueError(f"Video height must be at least {min_height}px, got {height}px") if max_height is not None and height > max_height: raise ValueError(f"Video height must be at most {max_height}px, got {height}px") @@ -138,13 +132,9 @@ def validate_video_duration( epsilon = 0.0001 if min_duration is not None and min_duration - epsilon > duration: - raise ValueError( - f"Video duration must be at least {min_duration}s, got {duration}s" - ) + raise ValueError(f"Video duration must be at least {min_duration}s, got {duration}s") if max_duration is not None and duration > max_duration + epsilon: - raise ValueError( - f"Video duration must be at most {max_duration}s, got {duration}s" - ) + raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s") def get_number_of_images(images): @@ -165,3 +155,31 @@ def validate_audio_duration( raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s") if max_duration is not None and dur - eps > max_duration: raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s") + + +def validate_string( + string: str, + strip_whitespace=True, + field_name="prompt", + min_length=None, + max_length=None, +): + if string is None: + raise Exception(f"Field '{field_name}' cannot be empty.") + if strip_whitespace: + string = string.strip() + if min_length and len(string) < min_length: + raise Exception( + f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long." + ) + if max_length and len(string) > max_length: + raise Exception( + f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long." + ) + + +def validate_container_format_is_mp4(video: VideoInput) -> None: + """Validates video container format is MP4.""" + container_format = video.get_container_format() + if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: + raise ValueError(f"Only MP4 container format supported. Got: {container_format}") diff --git a/pyproject.toml b/pyproject.toml index 0c6b23a253d7..fcc4854a576b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,8 @@ messages_control.disable = [ "too-many-branches", "too-many-locals", "too-many-arguments", + "too-many-return-statements", + "too-many-nested-blocks", "duplicate-code", "abstract-method", "superfluous-parens", From 5e9f33575367f0d6ac897665ed6dd99616b33d28 Mon Sep 17 00:00:00 2001 From: lspindler Date: Fri, 24 Oct 2025 14:44:54 +0200 Subject: [PATCH 28/49] An actually functional POC --- comfy/model_detection.py | 123 ++-------- comfy/ops.py | 268 ++++++++------------- comfy/quant_ops.py | 494 ++++++++++++++++++++++++++++----------- comfy/sd.py | 8 +- 4 files changed, 468 insertions(+), 425 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 01f26836b8d4..ffb1885fd1b4 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -7,121 +7,24 @@ import torch -# ============================================================================== -# Quantization Detection Functions -# ============================================================================== - -def normalize_layer_name(full_key, known_prefixes): - """ - Strip model prefix and parameter suffix from a state dict key. - - Args: - full_key: Full state dict key (e.g., "model.diffusion_model.layer1.weight") - known_prefixes: List of known model prefixes to strip - - Returns: - Normalized layer name (e.g., "layer1") - """ - name = full_key - - # Strip model prefix - for prefix in known_prefixes: - if name.startswith(prefix): - name = name[len(prefix):] - break - - # Remove parameter suffix - for suffix in [".weight", ".bias", ".scale_weight", ".scale_input"]: - if name.endswith(suffix): - name = name[:-len(suffix)] - break - - return name - - -def detect_layer_quantization(state_dict, prefix="model.diffusion_model."): - """ - Detect per-layer quantization configuration from state dict. - - Detection priority: - 1. Check for _quantization_metadata key (new format) - 2. Check for scaled_fp8 key (legacy format - return None) - 3. Check for per-layer scale_weight patterns (mixed detection) - 4. No quantization detected (return None) - - Args: - state_dict: Model state dictionary - prefix: Key prefix for model layers - - Returns: - Dict mapping layer names to quantization configs, or None for legacy/no quantization. - - Example return value: - { - "input_blocks.5.1.transformer_blocks.0.attn1.to_q": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": True} - }, - "middle_block.1.transformer_blocks.0.attn2.to_k": { - "format": "fp8_e5m2_scaled", - "params": {"use_fp8_matmul": True} - } - } - """ - - # 1. Check for new metadata format - metadata_key = f"{prefix}_quantization_metadata" - if metadata_key in state_dict: - try: - metadata = state_dict.pop(metadata_key) - if isinstance(metadata, dict) and "layers" in metadata: - logging.info(f"Found quantization metadata (version {metadata.get('format_version', 'unknown')})") - return metadata["layers"] - else: - logging.warning(f"Invalid quantization metadata format, ignoring") - except Exception as e: - logging.error(f"Failed to parse quantization metadata: {e}") - return None +def detect_layer_quantization(state_dict, metadata, prefix="model.diffusion_model."): + # 1. Check for per-layer config in metadata + quant_key = "_quantization_metadata" + if metadata is not None and quant_key in metadata: + quant_metadata = metadata.pop(quant_key) + quant_metadata = json.loads(quant_metadata) + if isinstance(quant_metadata, dict) and "layers" in quant_metadata: + logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})") + return quant_metadata["layers"] + else: + raise ValueError(f"Invalid quantization metadata format") # 2. Check for legacy scaled_fp8 marker - # If present, return None to use legacy code path scaled_fp8_key = f"{prefix}scaled_fp8" if scaled_fp8_key in state_dict: logging.debug("Detected legacy scaled_fp8 format, using legacy code path") return None - - # 3. Check for per-layer scale patterns (mixed precision without metadata) - # Look for layers that have scale_weight but not all layers have it - known_prefixes = [prefix] - layer_configs = {} - layers_with_scale = set() - layers_with_weight = set() - - for key in state_dict.keys(): - if key.startswith(prefix): - if key.endswith(".scale_weight"): - layer_name = normalize_layer_name(key, known_prefixes) - layers_with_scale.add(layer_name) - # Detect format based on weight dtype - weight_key = f"{prefix}{layer_name}.weight" - if weight_key in state_dict: - weight_dtype = state_dict[weight_key].dtype - if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - format_name = "fp8_e4m3fn" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2" - layer_configs[layer_name] = { - "format": format_name, - "params": {} - } - elif key.endswith(".weight") and not key.endswith(".scale_weight"): - layer_name = normalize_layer_name(key, known_prefixes) - layers_with_weight.add(layer_name) - - # If we found scale_weight on some but not all layers, it's mixed precision - if layer_configs and len(layers_with_scale) < len(layers_with_weight): - logging.info(f"Detected mixed precision via scale patterns: {len(layers_with_scale)} quantized layers, {len(layers_with_weight)} total layers") - return layer_configs - - # 4. No quantization detected + return None @@ -821,7 +724,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal model_config.optimizations["fp8"] = True # Detect per-layer quantization (mixed precision) - layer_quant_config = detect_layer_quantization(state_dict, unet_key_prefix) + layer_quant_config = detect_layer_quantization(state_dict, metadata, unet_key_prefix) if layer_quant_config: model_config.layer_quant_config = layer_quant_config logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized") diff --git a/comfy/ops.py b/comfy/ops.py index 060b35137f21..8d11aeefc0af 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -325,19 +325,8 @@ class Embedding(disable_weight_init.Embedding): def fp8_linear(self, input): """ - Legacy FP8 linear function - now uses tensor subclass infrastructure. - - This function maintains backward compatibility with existing code while - routing all FP8 computation through the unified tensor subclass system. - All actual FP8 matmul logic is handled by the registered operation handlers - in quant_ops.py via __torch_dispatch__. - - Args: - self: Linear layer with FP8 weight and scale parameters - input: Input tensor (any dtype) - - Returns: - Output tensor or None if weight is not FP8 + Legacy FP8 linear function for backward compatibility. + Uses QuantizedTensor subclass for dispatch. """ dtype = self.weight.dtype if dtype not in [torch.float8_e4m3fn]: @@ -352,10 +341,8 @@ def fp8_linear(self, input): input_dtype = input.dtype if len(input.shape) == 3: - # Get weight and bias using standard casting w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) - # Get scales (same as before) scale_weight = self.scale_weight scale_input = self.scale_input if scale_weight is None: @@ -368,14 +355,13 @@ def fp8_linear(self, input): else: scale_input = scale_input.to(input.device) - # Wrap weight in QuantizedTensorFP8 - this enables unified dispatch - quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype) - quantized_input = QuantizedTensorFP8.quantize(input.reshape(-1, input_shape[2]), scale_input, fp8_dtype=dtype) - # Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py! - # This is the key unification: all FP8 computation goes through one path + # Wrap weight in QuantizedTensor - this enables unified dispatch + # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! + layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} + quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) + quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, fp8_dtype=dtype) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) - # Reshape output if tensor_2d: return o.reshape(input_shape[0], -1) return o.reshape((-1, input_shape[1], self.weight.shape[0])) @@ -472,183 +458,117 @@ def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) -# Import quantization operations from separate module -from .quant_ops import QuantizedTensorFP8 - - # ============================================================================== # Mixed Precision Operations # ============================================================================== +from .quant_ops import QuantizedTensor, TensorCoreFP8Layout + +QUANT_FORMAT_MIXINS = { + "float8_e4m3fn": { + "dtype": torch.float8_e4m3fn, + "layout_type": TensorCoreFP8Layout, + "parameters": { + "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), + "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), + } + } +} class MixedPrecisionOps(disable_weight_init): - """ - Operations class supporting per-layer quantization (mixed precision). - - This class enables different layers to use different quantization formats - within the same model (e.g., some layers FP8, others BF16). - - Layer-specific quantization is configured via _layer_quant_config class variable, - which is set by pick_operations() when a model has mixed precision. - """ - - _layer_quant_config = {} # Class variable set by pick_operations() - - class Linear(disable_weight_init.Linear): - """Linear layer with optional per-layer quantization using tensor subclasses""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.quant_format = None - self.quant_scale = None - self._quantization_initialized = False - + _layer_quant_config = {} + _compute_dtype = torch.bfloat16 + + class Linear(torch.nn.Module, CastWeightBiasOp): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__() + + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + # self.factory_kwargs = {"device": device, "dtype": dtype} + + self.in_features = in_features + self.out_features = out_features + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) + + self.tensor_class = None + def reset_parameters(self): - # Don't allocate weights - return None like disable_weight_init return None def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - """ - Called by PyTorch during load_state_dict. - Load weight and wrap in QuantizedTensorFP8 if this layer is quantized. - """ - # Call parent to load weight and bias first - super()._load_from_state_dict( - state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs - ) - - # After weight is loaded, wrap it if this layer is quantized - if not self._quantization_initialized: - # Normalize layer name from prefix - layer_name = prefix.rstrip('.') - - # Strip known model prefixes - for model_prefix in ["model.diffusion_model.", "model.model.", "net."]: - if layer_name.startswith(model_prefix): - layer_name = layer_name[len(model_prefix):] - break - - # Check if this layer has quantization config - if layer_name in MixedPrecisionOps._layer_quant_config: - config = MixedPrecisionOps._layer_quant_config[layer_name] - self.quant_format = config.get("format", "fp8_e4m3fn") - - # Load scale parameter - scale_key = f"{prefix}scale_weight" - if scale_key in state_dict: - self.quant_scale = state_dict[scale_key] - - # Wrap weight in QuantizedTensorFP8 - if self.weight is not None and self.weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - try: - # Determine original dtype (default to bfloat16) - orig_dtype = torch.bfloat16 - - # Wrap weight in quantized tensor subclass - quantized_weight = QuantizedTensorFP8( - self.weight.data, - self.quant_scale, - orig_dtype=orig_dtype - ) - - # Replace weight parameter with wrapped version - self.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) - - logging.debug(f"Wrapped layer {layer_name} weight in QuantizedTensorFP8 (format: {self.quant_format})") - except Exception as e: - logging.warning(f"Failed to wrap layer {layer_name} in QuantizedTensorFP8: {e}") - self.quant_format = None - self.quant_scale = None - else: - logging.debug(f"Layer {layer_name} has scale but weight dtype is not FP8, skipping quantization") - self.quant_format = None - self.quant_scale = None - else: - logging.debug(f"Layer {layer_name} has quant config but no scale_weight in state_dict") - self.quant_format = None - - self._quantization_initialized = True - - def _save_to_state_dict(self, destination, prefix, keep_vars): - """Save layer parameters including quantization scale""" - # First unwrap the weight if it's quantized - if isinstance(self.weight, torch.nn.Parameter) and isinstance(self.weight.data, QuantizedTensorFP8): - # Temporarily unwrap to save the raw FP8 data - quantized_tensor = self.weight.data - raw_fp8_data = quantized_tensor._raw_data - original_weight = self.weight - self.weight = torch.nn.Parameter(raw_fp8_data, requires_grad=False) - - # Call parent to save unwrapped weight - super()._save_to_state_dict(destination, prefix, keep_vars) + + device = self.factory_kwargs["device"] + layer_name = prefix.rstrip('.') + weight_key = f"{prefix}weight" + weight = state_dict.pop(weight_key, None) + if weight is None: + raise ValueError(f"Missing weight for layer {layer_name}") + + if layer_name not in MixedPrecisionOps._layer_quant_config: + self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) + else: + quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) + if quant_format is None: + raise ValueError(f"Unknown quantization format for layer {layer_name}") - # Restore the wrapped weight - self.weight = original_weight + mixin = QUANT_FORMAT_MIXINS[quant_format] + self.layout_type = mixin["layout_type"] - # Save the scale parameter - if self.quant_scale is not None: - destination[f"{prefix}scale_weight"] = self.quant_scale if keep_vars else self.quant_scale.detach() - else: - # Standard path for non-quantized weights - super()._save_to_state_dict(destination, prefix, keep_vars) - + layout_params = { + 'scale': state_dict.pop(f"{prefix}weight_scale", None), + 'orig_dtype': MixedPrecisionOps._compute_dtype + } + self.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params), + requires_grad=False + ) + + for param_name, param_value in mixin["parameters"].items(): + _v = state_dict.pop(f"{prefix}{param_name}", None) + if _v is None: + continue + setattr(self, param_name, _v.to(device=device)) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + def _forward(self, input, weight, bias): + return torch.nn.functional.linear(input, weight, bias) + def forward_comfy_cast_weights(self, input): - """ - Forward pass - tensor subclass handles dispatch automatically! - __torch_dispatch__ will route to registered handlers based on tensor types. - """ weight, bias = cast_bias_weight(self, input) - - # Call F.linear - if weight is QuantizedTensorFP8, __torch_dispatch__ handles it! - return torch.nn.functional.linear(input, weight, bias) - - def forward(self, *args, **kwargs): - """Main forward pass""" + self._forward(input, weight, bias) + + def forward(self, input, *args, **kwargs): run_every_op() - # Same logic as disable_weight_init.Linear + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: - return self.forward_comfy_cast_weights(*args, **kwargs) - else: - return super().forward(*args, **kwargs) + return self.forward_comfy_cast_weights(input, *args, **kwargs) + if (getattr(self, 'layout_type', None) is not None and + getattr(self, 'input_scale', None) is not None and + not isinstance(input, QuantizedTensor)): + input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype) + return self._forward(input, self.weight, self.bias) - @classmethod - def conv_nd(s, dims, *args, **kwargs): - """Create Conv layer (same as disable_weight_init)""" - if dims == 2: - return s.Conv2d(*args, **kwargs) - elif dims == 3: - return s.Conv3d(*args, **kwargs) - else: - raise ValueError(f"unsupported dimensions: {dims}") - def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): - """ - Select appropriate operations class for model. - - NEW: If model_config.layer_quant_config exists, returns MixedPrecisionOps (Phase 3). - LEGACY: All other paths unchanged for backward compatibility. - - Args: - weight_dtype: Weight storage dtype - compute_dtype: Computation dtype - load_device: Device for loading - disable_fast_fp8: Disable fast FP8 paths - fp8_optimizations: Enable FP8 optimizations - scaled_fp8: Legacy FP8 dtype marker - model_config: Model config object (optional, for mixed precision support) - - Returns: - Operations class (e.g., MixedPrecisionOps, fp8_ops, disable_weight_init) - """ - # NEW: Check for mixed precision + # If model_config.layer_quant_config exists, use new MixedPrecisionOps. if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config + # MixedPrecisionOps._compute_dtype = compute_dtype # TODO + MixedPrecisionOps._compute_dtype = torch.bfloat16 logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") return MixedPrecisionOps - # LEGACY paths (unchanged) fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 8e3bacbaf8af..3802da8524e7 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -1,42 +1,79 @@ import torch import logging +from typing import Tuple, Dict -# ============================================================================== -# Global Operation Registry -# ============================================================================== +_LAYOUT_REGISTRY = {} +_GENERIC_UTILS = {} -# Global operation registry: torch operation → handler function -_QUANT_OP_REGISTRY = {} -def register_quant_op(torch_op): +def register_layout_op(torch_op, layout_type): """ - Decorator to register an operation handler. + Decorator to register a layout-specific operation handler. + Args: + torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default) + layout_type: Layout class (e.g., TensorCoreFP8Layout) + Example: + @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) + def fp8_linear(func, args, kwargs): + # FP8-specific linear implementation + ... + """ + def decorator(handler_func): + if torch_op not in _LAYOUT_REGISTRY: + _LAYOUT_REGISTRY[torch_op] = {} + _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func + return handler_func + return decorator + + +def register_generic_util(torch_op): + """ + Decorator to register a generic utility that works for all layouts. + Args: + torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default) Example: - @register_quant_op(torch.ops.aten.linear.default) - def handle_linear_fp8(func, args, kwargs): - # Implementation + @register_generic_util(torch.ops.aten.detach.default) + def generic_detach(func, args, kwargs): + # Works for any layout ... """ def decorator(handler_func): - _QUANT_OP_REGISTRY[torch_op] = handler_func + _GENERIC_UTILS[torch_op] = handler_func return handler_func return decorator -def get_quant_handler(torch_op): - """Get registered handler for an operation""" - return _QUANT_OP_REGISTRY.get(torch_op) +def _get_layout_from_args(args): + for arg in args: + if isinstance(arg, QuantizedTensor): + return arg._layout_type + elif isinstance(arg, (list, tuple)): + for item in arg: + if isinstance(item, QuantizedTensor): + return item._layout_type + return None -def list_registered_ops(): - """List all registered quantized operations""" - return list(_QUANT_OP_REGISTRY.keys()) +def _move_layout_params_to_device(params, device): + new_params = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + new_params[k] = v.to(device=device) + else: + new_params[k] = v + return new_params -# ============================================================================== -# comfy_kitchen Integration -# ============================================================================== +def _copy_layout_params(params): + new_params = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + new_params[k] = v.clone() + else: + new_params[k] = v + return new_params + try: import comfy_kitchen as ck @@ -53,106 +90,144 @@ def list_registered_ops(): logging.warning(f"comfy_kitchen import failed: {e} - using PyTorch fallbacks") -# ============================================================================== -# Quantized Tensor Subclass -# ============================================================================== +class QuantizedLayout: + """ + Base class for quantization layouts. + + A layout encapsulates the format-specific logic for quantization/dequantization + and provides a uniform interface for extracting raw tensors needed for computation. + + New quantization formats should subclass this and implement the required methods. + """ + @classmethod + def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]: + raise NotImplementedError(f"{cls.__name__} must implement quantize()") -class QuantizedTensorFP8(torch.Tensor): + @staticmethod + def dequantize(qdata, **layout_params) -> torch.Tensor: + raise NotImplementedError(f"TensorLayout must implement dequantize()") + + @classmethod + def get_plain_tensors(cls, qtensor) -> torch.Tensor: + raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()") + + +class QuantizedTensor(torch.Tensor): """ - Tensor subclass for FP8 quantized data. - Automatically handles operations via __torch_dispatch__. + Universal quantized tensor that works with any layout. + + This tensor subclass uses a pluggable layout system to support multiple + quantization formats (FP8, INT4, INT8, etc.) without code duplication. + + The layout_type determines format-specific behavior, while common operations + (detach, clone, to) are handled generically. + + Attributes: + _qdata: The quantized tensor data + _layout_type: Layout class (e.g., TensorCoreFP8Layout) + _layout_params: Dict with layout-specific params (scale, zero_point, etc.) """ @staticmethod - def __new__(cls, tensor, scale, orig_dtype=torch.bfloat16): + def __new__(cls, qdata, layout_type, layout_params): """ - Create a quantized FP8 tensor. + Create a quantized tensor. Args: - tensor: The FP8 tensor data (torch.float8_e4m3fn or e5m2) - scale: Scale factor for dequantization (scalar tensor) - orig_dtype: Original dtype before quantization + qdata: The quantized data tensor + layout_type: Layout class (subclass of QuantizedLayout) + layout_params: Dict with layout-specific parameters """ - return torch.Tensor._make_subclass(cls, tensor, require_grad=False) + return torch.Tensor._make_subclass(cls, qdata, require_grad=False) - def __init__(self, tensor, scale, orig_dtype=torch.bfloat16): - self._scale = scale - self._orig_dtype = orig_dtype - # Store a reference to prevent infinite recursion in dequantize - self._raw_data = tensor.contiguous() + def __init__(self, qdata, layout_type, layout_params): + self._qdata = qdata.contiguous() + self._layout_type = layout_type + self._layout_params = layout_params def __repr__(self): - return (f"QuantizedTensorFP8(shape={self.shape}, " - f"scale={self._scale:.4f}, dtype={self._orig_dtype})") + layout_name = self._layout_type.__name__ + param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) + return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" - @classmethod - def quantize(cls, tensor, scale, fp8_dtype=torch.float8_e4m3fn): - orig_dtype = tensor.dtype + @property + def layout_type(self): + return self._layout_type + + def __tensor_flatten__(self): + """ + Tensor flattening protocol for proper device movement. + """ + inner_tensors = ["_q_data"] + ctx = { + "layout_type": self._layout_type, + } - if not isinstance(scale, torch.Tensor): - scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32) + tensor_params = {} + non_tensor_params = {} + for k, v in self._layout_params.items(): + if isinstance(v, torch.Tensor): + tensor_params[k] = v + else: + non_tensor_params[k] = v - tensor_fp8 = None - if _CK_AVAILABLE: - try: - tensor_fp8 = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype) - except Exception as e: - logging.debug(f"comfy_kitchen quantization failed, using PyTorch: {e}") + ctx["tensor_param_keys"] = list(tensor_params.keys()) + ctx["non_tensor_params"] = non_tensor_params - if tensor_fp8 is None: - lp_amax = torch.finfo(fp8_dtype).max - tensor_scaled = tensor.float() / scale - torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) - tensor_fp8 = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format) + for k, v in tensor_params.items(): + attr_name = f"_layout_param_{k}" + object.__setattr__(self, attr_name, v) + inner_tensors.append(attr_name) + + return inner_tensors, ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): + """ + Tensor unflattening protocol for proper device movement. + Reconstructs the QuantizedTensor after device movement. + """ + layout_type = ctx["layout_type"] + layout_params = dict(ctx["non_tensor_params"]) + + for key in ctx["tensor_param_keys"]: + attr_name = f"_layout_param_{key}" + layout_params[key] = inner_tensors[attr_name] - return cls(tensor_fp8, scale, orig_dtype=orig_dtype) + return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params) @classmethod - def quantize_dynamic(cls, tensor, strategy="amax", fp8_dtype=torch.float8_e4m3fn): - if strategy == "amax": - scale = torch.amax(tensor) / torch.finfo(fp8_dtype).max - scale = scale.to(tensor.device, dtype=torch.float32) - else: - raise ValueError(f"Unknown quantization strategy: {strategy}. " - f"Supported: 'amax'") - - return cls.quantize(tensor, scale, fp8_dtype=fp8_dtype) + def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': + qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs) + return cls(qdata, layout_type, layout_params) + + def dequantize(self) -> torch.Tensor: + return self._layout_type.dequantize(self._qdata, **self._layout_params) @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): kwargs = kwargs or {} - # Special case: skip dispatch for internal tensor operations - # that are used for unwrapping (to avoid recursion) - if func in [torch.ops.aten._to_copy.default, torch.ops.aten.detach.default]: - # For these ops, use the raw data to avoid recursion, but return QuantizedTensorFP8 for detach - if func == torch.ops.aten.detach.default and isinstance(args[0], QuantizedTensorFP8): - # Special handling for detach - return a new QuantizedTensorFP8 - qt = args[0] - detached_data = qt._raw_data.detach() - return QuantizedTensorFP8(detached_data, qt._scale, qt._orig_dtype) - - # For other ops, just unwrap - def unwrap(arg): - if isinstance(arg, QuantizedTensorFP8): - return arg._raw_data - return arg - new_args = tuple(unwrap(a) if not isinstance(a, (list, tuple, dict)) else a for a in args) - return func(*new_args, **kwargs) + # Step 1: Check generic utilities first (detach, clone, to, etc.) + if func in _GENERIC_UTILS: + return _GENERIC_UTILS[func](func, args, kwargs) - # Look up registered handler for this operation - handler = _QUANT_OP_REGISTRY.get(func) - if handler: - return handler(func, args, kwargs) + # Step 2: Check layout-specific handlers (linear, matmul, etc.) + layout_type = _get_layout_from_args(args) + if layout_type and func in _LAYOUT_REGISTRY: + handler = _LAYOUT_REGISTRY[func].get(layout_type) + if handler: + return handler(func, args, kwargs) - # No handler - dequantize and use standard path + # Step 3: Fallback to dequantization + if isinstance(args[0] if args else None, QuantizedTensor): + logging.warning(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") return cls._dequant_and_fallback(func, args, kwargs) @classmethod def _dequant_and_fallback(cls, func, args, kwargs): - """Fallback: dequantize all quantized tensors""" def dequant_arg(arg): - if isinstance(arg, QuantizedTensorFP8): + if isinstance(arg, QuantizedTensor): return arg.dequantize() elif isinstance(arg, (list, tuple)): return type(arg)(dequant_arg(a) for a in arg) @@ -161,75 +236,220 @@ def dequant_arg(arg): new_args = dequant_arg(args) new_kwargs = dequant_arg(kwargs) return func(*new_args, **new_kwargs) + + +# ============================================================================== +# Generic Utilities (Layout-Agnostic Operations) +# ============================================================================== + +def _create_transformed_qtensor(qt, transform_fn): + new_data = transform_fn(qt._qdata) + new_params = _copy_layout_params(qt._layout_params) + return QuantizedTensor(new_data, qt._layout_type, new_params) + + +def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"): + if target_dtype is not None and target_dtype != qt.dtype: + logging.warning( + f"QuantizedTensor: dtype conversion requested to {target_dtype}, " + f"but not supported for quantized tensors. Ignoring dtype." + ) - def dequantize(self) -> torch.Tensor: - plain_tensor = torch.ops.aten._to_copy.default(self._raw_data, dtype=self._orig_dtype) - return plain_tensor * self._scale + if target_layout is not None and target_layout != torch.strided: + logging.warning( + f"QuantizedTensor: layout change requested to {target_layout}, " + f"but not supported. Ignoring layout." + ) + + # Handle device transfer + current_device = qt._qdata.device + if target_device is not None: + # Normalize device for comparison + if isinstance(target_device, str): + target_device = torch.device(target_device) + if isinstance(current_device, str): + current_device = torch.device(current_device) + + if target_device != current_device: + logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") + new_q_data = qt._qdata.to(device=target_device) + new_params = _move_layout_params_to_device(qt._layout_params, target_device) + new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) + logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") + return new_qt - def detach(self): - """Detach returns a new QuantizedTensorFP8 (required for Parameter)""" - detached_data = self._raw_data.detach() - return QuantizedTensorFP8(detached_data, self._scale, self._orig_dtype) + logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original") + return qt + + +@register_generic_util(torch.ops.aten.detach.default) +def generic_detach(func, args, kwargs): + """Detach operation - creates a detached copy of the quantized tensor.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _create_transformed_qtensor(qt, lambda x: x.detach()) + return func(*args, **kwargs) + +@register_generic_util(torch.ops.aten.clone.default) +def generic_clone(func, args, kwargs): + """Clone operation - creates a deep copy of the quantized tensor.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _create_transformed_qtensor(qt, lambda x: x.clone()) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten._to_copy.default) +def generic_to_copy(func, args, kwargs): + """Device/dtype transfer operation - handles .to(device) calls.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _handle_device_transfer( + qt, + target_device=kwargs.get('device', None), + target_dtype=kwargs.get('dtype', None), + op_name="_to_copy" + ) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.to.dtype_layout) +def generic_to_dtype_layout(func, args, kwargs): + """Handle .to(device) calls using the dtype_layout variant.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _handle_device_transfer( + qt, + target_device=kwargs.get('device', None), + target_dtype=kwargs.get('dtype', None), + target_layout=kwargs.get('layout', None), + op_name="to" + ) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.copy_.default) +def generic_copy_(func, args, kwargs): + qt_dest = args[0] + src = args[1] + + if isinstance(qt_dest, QuantizedTensor): + if isinstance(src, QuantizedTensor): + # Copy from another quantized tensor + qt_dest._qdata.copy_(src._qdata) + qt_dest._layout_type = src._layout_type + qt_dest._layout_params = _copy_layout_params(src._layout_params) + else: + # Copy from regular tensor - just copy raw data + qt_dest._qdata.copy_(src) + return qt_dest + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) +def generic_has_compatible_shallow_copy_type(func, args, kwargs): + return True # ============================================================================== -# Operation Handlers for Quantized Tensors +# FP8 Layout + Operation Handlers # ============================================================================== - -@register_quant_op(torch.ops.aten.linear.default) -def handle_linear_fp8(func, args, kwargs): +class TensorCoreFP8Layout(QuantizedLayout): """ - Handle F.linear() with quantized inputs. - - Supports: - - QuantizedTensorFP8 input + QuantizedTensorFP8 weight - - QuantizedTensorFP8 input + regular weight - - Regular input + QuantizedTensorFP8 weight + Storage format: + - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2) + - scale: Scalar tensor (float32) for dequantization + - orig_dtype: Original dtype before quantization (for casting back) """ + @classmethod + def quantize(cls, tensor, scale=None, fp8_dtype=torch.float8_e4m3fn): + orig_dtype = tensor.dtype + + if scale is None: + scale = torch.amax(tensor.abs()) / torch.finfo(fp8_dtype).max + + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) + + if _CK_AVAILABLE and tensor.device.type == "cuda": + qdata = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype) + else: + lp_amax = torch.finfo(fp8_dtype).max + tensor_scaled = tensor.float() / scale + torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) + qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format) + + layout_params = { + 'scale': scale, + 'orig_dtype': orig_dtype + } + return qdata, layout_params + + @staticmethod + def dequantize(qdata, scale, orig_dtype, **kwargs): + plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype) + return plain_tensor * scale + + @classmethod + def get_plain_tensors(cls, qtensor): + return qtensor._qdata, qtensor._layout_params['scale'] + + +@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) +def fp8_linear(func, args, kwargs): input_tensor = args[0] weight = args[1] bias = args[2] if len(args) > 2 else None - out_dtype = kwargs.get("out_dtype", input_tensor._orig_dtype) - - # Case 1: Both input and weight are FP8 - if isinstance(input_tensor, QuantizedTensorFP8) and isinstance(weight, QuantizedTensorFP8): - # Get plain tensors to avoid dispatch recursion - plain_input = input_tensor._raw_data - plain_weight = weight._raw_data - weight_t = plain_weight.t() # Keep as column-major for cuBLASLt + + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) + + out_dtype = kwargs.get("out_dtype") + if out_dtype is None: + out_dtype = input_tensor._layout_params['orig_dtype'] + weight_t = plain_weight.t() + + tensor_2d = False + if len(plain_input.shape) == 2: + tensor_2d = True + plain_input = plain_input.unsqueeze(1) + + input_shape = plain_input.shape + if len(input_shape) != 3: + return None + try: output = torch._scaled_mm( - plain_input, + plain_input.reshape(-1, input_shape[2]), weight_t, bias=bias, - scale_a=input_tensor._scale, - scale_b=weight._scale, + scale_a=scale_a, + scale_b=scale_b, out_dtype=out_dtype, ) - if isinstance(output, tuple): - output = output[0] - + if not tensor_2d: + output = output.reshape((-1, input_shape[1], weight.shape[0])) + if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - output_scale = input_tensor._scale * weight._scale - return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype) # TODO is this correct? Can't cuBLAS return it + output_scale = scale_a * scale_b + output_params = { + 'scale': output_scale, + 'orig_dtype': input_tensor._layout_params['orig_dtype'] + } + return QuantizedTensor(output, TensorCoreFP8Layout, output_params) else: return output + except Exception as e: - logging.debug(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") + raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") - # Case 2: Only weight is quantized - if isinstance(weight, QuantizedTensorFP8): - weight_dq = weight.dequantize() - input_dq = input_tensor.dequantize() if isinstance(input_tensor, QuantizedTensorFP8) else input_tensor - return torch.nn.functional.linear(input_dq, weight_dq, bias) - - # Case 3: Only input is quantized - elif isinstance(input_tensor, QuantizedTensorFP8): - input_dq = input_tensor.dequantize() - return torch.nn.functional.linear(input_dq, weight, bias) - - # Case 4: Neither is quantized (shouldn't happen, but handle it) - else: - return torch.nn.functional.linear(input_tensor, weight, bias) + # Case 2: DQ Fallback + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + if isinstance(input_tensor, QuantizedTensor): + input_tensor = input_tensor.dequantize() + return torch.nn.functional.linear(input_tensor, weight, bias) diff --git a/comfy/sd.py b/comfy/sd.py index 28bee248dae1..b965e98427d9 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1262,7 +1262,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return (model_patcher, clip, vae, clipvision) -def load_diffusion_model_state_dict(sd, model_options={}): +def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): """ Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. @@ -1296,7 +1296,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): weight_dtype = comfy.utils.weight_dtype(sd) load_device = model_management.get_torch_device() - model_config = model_detection.model_config_from_unet(sd, "") + model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata) if model_config is not None: new_sd = sd @@ -1346,8 +1346,8 @@ def load_diffusion_model_state_dict(sd, model_options={}): def load_diffusion_model(unet_path, model_options={}): - sd = comfy.utils.load_torch_file(unet_path) - model = load_diffusion_model_state_dict(sd, model_options=model_options) + sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True) + model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata) if model is None: logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) From dd5af0c5871376c377b2e30f9725b67a768eea6f Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 25 Oct 2025 01:48:34 +0300 Subject: [PATCH 29/49] convert Tripo API nodes to V3 schema (#10469) --- comfy_api_nodes/apis/tripo_api.py | 15 +- comfy_api_nodes/nodes_tripo.py | 906 ++++++++++++----------- comfy_api_nodes/util/__init__.py | 2 + comfy_api_nodes/util/download_helpers.py | 12 + 4 files changed, 510 insertions(+), 425 deletions(-) diff --git a/comfy_api_nodes/apis/tripo_api.py b/comfy_api_nodes/apis/tripo_api.py index 9f43d4d0901b..713260e2a4f4 100644 --- a/comfy_api_nodes/apis/tripo_api.py +++ b/comfy_api_nodes/apis/tripo_api.py @@ -1,13 +1,20 @@ from __future__ import annotations -from comfy_api_nodes.apis import ( - TripoModelVersion, - TripoTextureQuality, -) from enum import Enum from typing import Optional, List, Dict, Any, Union from pydantic import BaseModel, Field, RootModel +class TripoModelVersion(str, Enum): + v2_5_20250123 = 'v2.5-20250123' + v2_0_20240919 = 'v2.0-20240919' + v1_4_20240625 = 'v1.4-20240625' + + +class TripoTextureQuality(str, Enum): + standard = 'standard' + detailed = 'detailed' + + class TripoStyle(str, Enum): PERSON_TO_CARTOON = "person:person2cartoon" ANIMAL_VENOM = "animal:venom" diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index d08cf9007655..697100ff2662 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -1,46 +1,39 @@ import os -from folder_paths import get_output_directory -from comfy_api_nodes.mapper_utils import model_field_to_node_input -from comfy.comfy_types.node_typing import IO -from comfy_api_nodes.apis import ( - TripoOrientation, - TripoModelVersion, -) +from typing import Optional + +import torch +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis.tripo_api import ( - TripoTaskType, - TripoStyle, - TripoFileReference, + TripoAnimateRetargetRequest, + TripoAnimateRigRequest, + TripoConvertModelRequest, TripoFileEmptyReference, - TripoUrlReference, + TripoFileReference, + TripoImageToModelRequest, + TripoModelVersion, + TripoMultiviewToModelRequest, + TripoOrientation, + TripoRefineModelRequest, + TripoStyle, TripoTaskResponse, TripoTaskStatus, + TripoTaskType, TripoTextToModelRequest, - TripoImageToModelRequest, - TripoMultiviewToModelRequest, TripoTextureModelRequest, - TripoRefineModelRequest, - TripoAnimateRigRequest, - TripoAnimateRetargetRequest, - TripoConvertModelRequest, + TripoUrlReference, ) - -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( + download_url_as_bytesio, + poll_op, + sync_op, upload_images_to_comfyapi, - download_url_to_bytesio, ) +from folder_paths import get_output_directory -async def upload_image_to_tripo(image, **kwargs): - urls = await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs) - return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg")) - def get_model_url_from_response(response: TripoTaskResponse) -> str: if response.data is not None: for key in ["pbr_model", "model", "base_model"]: @@ -50,20 +43,18 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str: async def poll_until_finished( - kwargs: dict[str, str], + node_cls: type[IO.ComfyNode], response: TripoTaskResponse, -) -> tuple[str, str]: + average_duration: Optional[int] = None, +) -> IO.NodeOutput: """Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response.""" if response.code != 0: raise RuntimeError(f"Failed to generate mesh: {response.error}") task_id = response.data.task_id - response_poll = await PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/tripo/v2/openapi/task/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TripoTaskResponse, - ), + response_poll = await poll_op( + node_cls, + poll_endpoint=ApiEndpoint(path=f"/proxy/tripo/v2/openapi/task/{task_id}"), + response_model=TripoTaskResponse, completed_statuses=[TripoTaskStatus.SUCCESS], failed_statuses=[ TripoTaskStatus.FAILED, @@ -73,72 +64,84 @@ async def poll_until_finished( TripoTaskStatus.EXPIRED, ], status_extractor=lambda x: x.data.status, - auth_kwargs=kwargs, - node_id=kwargs["unique_id"], - result_url_extractor=get_model_url_from_response, progress_extractor=lambda x: x.data.progress, - ).execute() + estimated_duration=average_duration, + ) if response_poll.data.status == TripoTaskStatus.SUCCESS: url = get_model_url_from_response(response_poll) - bytesio = await download_url_to_bytesio(url) + bytesio = await download_url_as_bytesio(url) # Save the downloaded model file model_file = f"tripo_model_{task_id}.glb" with open(os.path.join(get_output_directory(), model_file), "wb") as f: f.write(bytesio.getvalue()) - return model_file, task_id + return IO.NodeOutput(model_file, task_id) raise RuntimeError(f"Failed to generate mesh: {response_poll}") -class TripoTextToModelNode: +class TripoTextToModelNode(IO.ComfyNode): """ Generates 3D models synchronously based on a text prompt using Tripo's API. """ - AVERAGE_DURATION = 80 + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TripoTextToModelNode", + display_name="Tripo: Text to Model", + category="api node/3d/Tripo", + inputs=[ + IO.String.Input("prompt", multiline=True), + IO.String.Input("negative_prompt", multiline=True, optional=True), + IO.Combo.Input( + "model_version", options=TripoModelVersion, default=TripoModelVersion.v2_5_20250123, optional=True + ), + IO.Combo.Input("style", options=TripoStyle, default="None", optional=True), + IO.Boolean.Input("texture", default=True, optional=True), + IO.Boolean.Input("pbr", default=True, optional=True), + IO.Int.Input("image_seed", default=42, optional=True), + IO.Int.Input("model_seed", default=42, optional=True), + IO.Int.Input("texture_seed", default=42, optional=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), + IO.Boolean.Input("quad", default=False, optional=True), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ("STRING", {"multiline": True}), - }, - "optional": { - "negative_prompt": ("STRING", {"multiline": True}), - "model_version": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "model_version", enum_type=TripoModelVersion), - "style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"), - "texture": ("BOOLEAN", {"default": True}), - "pbr": ("BOOLEAN", {"default": True}), - "image_seed": ("INT", {"default": 42}), - "model_seed": ("INT", {"default": 42}), - "texture_seed": ("INT", {"default": 42}), - "texture_quality": (["standard", "detailed"], {"default": "standard"}), - "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), - "quad": ("BOOLEAN", {"default": False}) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - - async def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): + async def execute( + cls, + prompt: str, + negative_prompt: Optional[str] = None, + model_version=None, + style: Optional[str] = None, + texture: Optional[bool] = None, + pbr: Optional[bool] = None, + image_seed: Optional[int] = None, + model_seed: Optional[int] = None, + texture_seed: Optional[int] = None, + texture_quality: Optional[str] = None, + face_limit: Optional[int] = None, + quad: Optional[bool] = None, + ) -> IO.NodeOutput: style_enum = None if style == "None" else style if not prompt: raise RuntimeError("Prompt is required") - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoTextToModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoTextToModelRequest( + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoTextToModelRequest( type=TripoTaskType.TEXT_TO_MODEL, prompt=prompt, negative_prompt=negative_prompt if negative_prompt else None, @@ -152,64 +155,89 @@ async def generate_mesh(self, prompt, negative_prompt=None, model_version=None, texture_quality=texture_quality, face_limit=face_limit, auto_size=True, - quad=quad + quad=quad, ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=80) -class TripoImageToModelNode: +class TripoImageToModelNode(IO.ComfyNode): """ Generates 3D models synchronously based on a single image using Tripo's API. """ - AVERAGE_DURATION = 80 + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - }, - "optional": { - "model_version": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "model_version", enum_type=TripoModelVersion), - "style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"), - "texture": ("BOOLEAN", {"default": True}), - "pbr": ("BOOLEAN", {"default": True}), - "model_seed": ("INT", {"default": 42}), - "orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation), - "texture_seed": ("INT", {"default": 42}), - "texture_quality": (["standard", "detailed"], {"default": "standard"}), - "texture_alignment": (["original_image", "geometry"], {"default": "original_image"}), - "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), - "quad": ("BOOLEAN", {"default": False}) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - - async def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): + def define_schema(cls): + return IO.Schema( + node_id="TripoImageToModelNode", + display_name="Tripo: Image to Model", + category="api node/3d/Tripo", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input( + "model_version", + options=TripoModelVersion, + tooltip="The model version to use for generation", + optional=True, + ), + IO.Combo.Input("style", options=TripoStyle, default="None", optional=True), + IO.Boolean.Input("texture", default=True, optional=True), + IO.Boolean.Input("pbr", default=True, optional=True), + IO.Int.Input("model_seed", default=42, optional=True), + IO.Combo.Input( + "orientation", options=TripoOrientation, default=TripoOrientation.DEFAULT, optional=True + ), + IO.Int.Input("texture_seed", default=42, optional=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Combo.Input( + "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True + ), + IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), + IO.Boolean.Input("quad", default=False, optional=True), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) + + @classmethod + async def execute( + cls, + image: torch.Tensor, + model_version: Optional[str] = None, + style: Optional[str] = None, + texture: Optional[bool] = None, + pbr: Optional[bool] = None, + model_seed: Optional[int] = None, + orientation=None, + texture_seed: Optional[int] = None, + texture_quality: Optional[str] = None, + texture_alignment: Optional[str] = None, + face_limit: Optional[int] = None, + quad: Optional[bool] = None, + ) -> IO.NodeOutput: style_enum = None if style == "None" else style if image is None: raise RuntimeError("Image is required") - tripo_file = await upload_image_to_tripo(image, **kwargs) - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoImageToModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoImageToModelRequest( + tripo_file = TripoFileReference( + root=TripoUrlReference( + url=(await upload_images_to_comfyapi(cls, image, max_images=1))[0], + type="jpeg", + ) + ) + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoImageToModelRequest( type=TripoTaskType.IMAGE_TO_MODEL, file=tripo_file, model_version=model_version, @@ -223,80 +251,105 @@ async def generate_mesh(self, image, model_version=None, style=None, texture=Non texture_quality=texture_quality, face_limit=face_limit, auto_size=True, - quad=quad + quad=quad, ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=80) -class TripoMultiviewToModelNode: +class TripoMultiviewToModelNode(IO.ComfyNode): """ Generates 3D models synchronously based on up to four images (front, left, back, right) using Tripo's API. """ - AVERAGE_DURATION = 80 + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TripoMultiviewToModelNode", + display_name="Tripo: Multiview to Model", + category="api node/3d/Tripo", + inputs=[ + IO.Image.Input("image"), + IO.Image.Input("image_left", optional=True), + IO.Image.Input("image_back", optional=True), + IO.Image.Input("image_right", optional=True), + IO.Combo.Input( + "model_version", + options=TripoModelVersion, + optional=True, + tooltip="The model version to use for generation", + ), + IO.Combo.Input( + "orientation", + options=TripoOrientation, + default=TripoOrientation.DEFAULT, + optional=True, + ), + IO.Boolean.Input("texture", default=True, optional=True), + IO.Boolean.Input("pbr", default=True, optional=True), + IO.Int.Input("model_seed", default=42, optional=True), + IO.Int.Input("texture_seed", default=42, optional=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Combo.Input( + "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True + ), + IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), + IO.Boolean.Input("quad", default=False, optional=True), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - }, - "optional": { - "image_left": ("IMAGE",), - "image_back": ("IMAGE",), - "image_right": ("IMAGE",), - "model_version": model_field_to_node_input(IO.COMBO, TripoMultiviewToModelRequest, "model_version", enum_type=TripoModelVersion), - "orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation), - "texture": ("BOOLEAN", {"default": True}), - "pbr": ("BOOLEAN", {"default": True}), - "model_seed": ("INT", {"default": 42}), - "texture_seed": ("INT", {"default": 42}), - "texture_quality": (["standard", "detailed"], {"default": "standard"}), - "texture_alignment": (["original_image", "geometry"], {"default": "original_image"}), - "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), - "quad": ("BOOLEAN", {"default": False}) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - - async def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs): + async def execute( + cls, + image: torch.Tensor, + image_left: Optional[torch.Tensor] = None, + image_back: Optional[torch.Tensor] = None, + image_right: Optional[torch.Tensor] = None, + model_version: Optional[str] = None, + orientation: Optional[str] = None, + texture: Optional[bool] = None, + pbr: Optional[bool] = None, + model_seed: Optional[int] = None, + texture_seed: Optional[int] = None, + texture_quality: Optional[str] = None, + texture_alignment: Optional[str] = None, + face_limit: Optional[int] = None, + quad: Optional[bool] = None, + ) -> IO.NodeOutput: if image is None: raise RuntimeError("front image for multiview is required") images = [] - image_dict = { - "image": image, - "image_left": image_left, - "image_back": image_back, - "image_right": image_right - } + image_dict = {"image": image, "image_left": image_left, "image_back": image_back, "image_right": image_right} if image_left is None and image_back is None and image_right is None: raise RuntimeError("At least one of left, back, or right image must be provided for multiview") for image_name in ["image", "image_left", "image_back", "image_right"]: image_ = image_dict[image_name] if image_ is not None: - tripo_file = await upload_image_to_tripo(image_, **kwargs) - images.append(tripo_file) + images.append( + TripoFileReference( + root=TripoUrlReference( + url=(await upload_images_to_comfyapi(cls, image_, max_images=1))[0], type="jpeg" + ) + ) + ) else: images.append(TripoFileEmptyReference()) - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoMultiviewToModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoMultiviewToModelRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoMultiviewToModelRequest( type=TripoTaskType.MULTIVIEW_TO_MODEL, files=images, model_version=model_version, @@ -310,272 +363,283 @@ async def generate_mesh(self, image, image_left=None, image_back=None, image_rig face_limit=face_limit, quad=quad, ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=80) -class TripoTextureNode: +class TripoTextureNode(IO.ComfyNode): + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model_task_id": ("MODEL_TASK_ID",), - }, - "optional": { - "texture": ("BOOLEAN", {"default": True}), - "pbr": ("BOOLEAN", {"default": True}), - "texture_seed": ("INT", {"default": 42}), - "texture_quality": (["standard", "detailed"], {"default": "standard"}), - "texture_alignment": (["original_image", "geometry"], {"default": "original_image"}), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 80 - - async def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs): - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoTextureModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoTextureModelRequest( + def define_schema(cls): + return IO.Schema( + node_id="TripoTextureNode", + display_name="Tripo: Texture model", + category="api node/3d/Tripo", + inputs=[ + IO.Custom("MODEL_TASK_ID").Input("model_task_id"), + IO.Boolean.Input("texture", default=True, optional=True), + IO.Boolean.Input("pbr", default=True, optional=True), + IO.Int.Input("texture_seed", default=42, optional=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Combo.Input( + "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) + + @classmethod + async def execute( + cls, + model_task_id, + texture: Optional[bool] = None, + pbr: Optional[bool] = None, + texture_seed: Optional[int] = None, + texture_quality: Optional[str] = None, + texture_alignment: Optional[str] = None, + ) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoTextureModelRequest( original_model_task_id=model_task_id, texture=texture, pbr=pbr, texture_seed=texture_seed, texture_quality=texture_quality, - texture_alignment=texture_alignment + texture_alignment=texture_alignment, ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=80) + +class TripoRefineNode(IO.ComfyNode): -class TripoRefineNode: @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model_task_id": ("MODEL_TASK_ID", { - "tooltip": "Must be a v1.4 Tripo model" - }), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Refine a draft model created by v1.4 Tripo models only." - - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 240 - - async def generate_mesh(self, model_task_id, **kwargs): - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoRefineModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoRefineModelRequest( - draft_model_task_id=model_task_id - ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + def define_schema(cls): + return IO.Schema( + node_id="TripoRefineNode", + display_name="Tripo: Refine Draft model", + category="api node/3d/Tripo", + description="Refine a draft model created by v1.4 Tripo models only.", + inputs=[ + IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) + + @classmethod + async def execute(cls, model_task_id) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoRefineModelRequest(draft_model_task_id=model_task_id), + ) + return await poll_until_finished(cls, response, average_duration=240) + +class TripoRigNode(IO.ComfyNode): -class TripoRigNode: @classmethod - def INPUT_TYPES(s): - return { - "required": { - "original_model_task_id": ("MODEL_TASK_ID",), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("STRING", "RIG_TASK_ID") - RETURN_NAMES = ("model_file", "rig task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 180 - - async def generate_mesh(self, original_model_task_id, **kwargs): - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoAnimateRigRequest, - response_model=TripoTaskResponse, - ), - request=TripoAnimateRigRequest( - original_model_task_id=original_model_task_id, - out_format="glb", - spec="tripo" - ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + def define_schema(cls): + return IO.Schema( + node_id="TripoRigNode", + display_name="Tripo: Rig model", + category="api node/3d/Tripo", + inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("RIG_TASK_ID").Output(display_name="rig task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) + + @classmethod + async def execute(cls, original_model_task_id) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoAnimateRigRequest(original_model_task_id=original_model_task_id, out_format="glb", spec="tripo"), + ) + return await poll_until_finished(cls, response, average_duration=180) -class TripoRetargetNode: +class TripoRetargetNode(IO.ComfyNode): + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "original_model_task_id": ("RIG_TASK_ID",), - "animation": ([ - "preset:idle", - "preset:walk", - "preset:climb", - "preset:jump", - "preset:slash", - "preset:shoot", - "preset:hurt", - "preset:fall", - "preset:turn", - ],), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("STRING", "RETARGET_TASK_ID") - RETURN_NAMES = ("model_file", "retarget task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 30 - - async def generate_mesh(self, animation, original_model_task_id, **kwargs): - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoAnimateRetargetRequest, - response_model=TripoTaskResponse, - ), - request=TripoAnimateRetargetRequest( + def define_schema(cls): + return IO.Schema( + node_id="TripoRetargetNode", + display_name="Tripo: Retarget rigged model", + category="api node/3d/Tripo", + inputs=[ + IO.Custom("RIG_TASK_ID").Input("original_model_task_id"), + IO.Combo.Input( + "animation", + options=[ + "preset:idle", + "preset:walk", + "preset:climb", + "preset:jump", + "preset:slash", + "preset:shoot", + "preset:hurt", + "preset:fall", + "preset:turn", + ], + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("RETARGET_TASK_ID").Output(display_name="retarget task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) + + @classmethod + async def execute(cls, original_model_task_id, animation: str) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoAnimateRetargetRequest( original_model_task_id=original_model_task_id, animation=animation, out_format="glb", - bake_animation=True + bake_animation=True, ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=30) -class TripoConversionNode: +class TripoConversionNode(IO.ComfyNode): + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "original_model_task_id": ("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID",), - "format": (["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"],), - }, - "optional": { - "quad": ("BOOLEAN", {"default": False}), - "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), - "texture_size": ("INT", {"min": 128, "max": 4096, "default": 4096}), - "texture_format": (["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"], {"default": "JPEG"}) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoConversionNode", + display_name="Tripo: Convert model", + category="api node/3d/Tripo", + inputs=[ + IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"), + IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]), + IO.Boolean.Input("quad", default=False, optional=True), + IO.Int.Input( + "face_limit", + default=-1, + min=-1, + max=500000, + optional=True, + ), + IO.Int.Input( + "texture_size", + default=4096, + min=128, + max=4096, + optional=True, + ), + IO.Combo.Input( + "texture_format", + options=["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"], + default="JPEG", + optional=True, + ), + ], + outputs=[], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) @classmethod - def VALIDATE_INPUTS(cls, input_types): + def validate_inputs(cls, input_types): # The min and max of input1 and input2 are still validated because # we didn't take `input1` or `input2` as arguments if input_types["original_model_task_id"] not in ("MODEL_TASK_ID", "RIG_TASK_ID", "RETARGET_TASK_ID"): return "original_model_task_id must be MODEL_TASK_ID, RIG_TASK_ID or RETARGET_TASK_ID type" return True - RETURN_TYPES = () - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 30 - - async def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs): + @classmethod + async def execute( + cls, + original_model_task_id, + format: str, + quad: bool, + face_limit: int, + texture_size: int, + texture_format: str, + ) -> IO.NodeOutput: if not original_model_task_id: raise RuntimeError("original_model_task_id is required") - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoConvertModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoConvertModelRequest( + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoConvertModelRequest( original_model_task_id=original_model_task_id, format=format, quad=quad if quad else None, face_limit=face_limit if face_limit != -1 else None, texture_size=texture_size if texture_size != 4096 else None, - texture_format=texture_format if texture_format != "JPEG" else None + texture_format=texture_format if texture_format != "JPEG" else None, ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) - - -NODE_CLASS_MAPPINGS = { - "TripoTextToModelNode": TripoTextToModelNode, - "TripoImageToModelNode": TripoImageToModelNode, - "TripoMultiviewToModelNode": TripoMultiviewToModelNode, - "TripoTextureNode": TripoTextureNode, - "TripoRefineNode": TripoRefineNode, - "TripoRigNode": TripoRigNode, - "TripoRetargetNode": TripoRetargetNode, - "TripoConversionNode": TripoConversionNode, -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "TripoTextToModelNode": "Tripo: Text to Model", - "TripoImageToModelNode": "Tripo: Image to Model", - "TripoMultiviewToModelNode": "Tripo: Multiview to Model", - "TripoTextureNode": "Tripo: Texture model", - "TripoRefineNode": "Tripo: Refine Draft model", - "TripoRigNode": "Tripo: Rig model", - "TripoRetargetNode": "Tripo: Retarget rigged model", - "TripoConversionNode": "Tripo: Convert model", -} + ) + return await poll_until_finished(cls, response, average_duration=30) + + +class TripoExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + TripoTextToModelNode, + TripoImageToModelNode, + TripoMultiviewToModelNode, + TripoTextureNode, + TripoRefineNode, + TripoRigNode, + TripoRetargetNode, + TripoConversionNode, + ] + + +async def comfy_entrypoint() -> TripoExtension: + return TripoExtension() diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index c2ec391aadd4..ab96760cbce1 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -20,6 +20,7 @@ trim_video, ) from .download_helpers import ( + download_url_as_bytesio, download_url_to_bytesio, download_url_to_image_tensor, download_url_to_video_output, @@ -56,6 +57,7 @@ "upload_images_to_comfyapi", "upload_video_to_comfyapi", # Download helpers + "download_url_as_bytesio", "download_url_to_bytesio", "download_url_to_image_tensor", "download_url_to_video_output", diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 055e690de4e5..791dd5a5027b 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -240,6 +240,18 @@ async def download_url_to_video_output( return VideoFromFile(result) +async def download_url_as_bytesio( + url: str, + *, + timeout: float = None, + cls: type[COMFY_IO.ComfyNode] = None, +) -> BytesIO: + """Downloads content from a URL and returns a new BytesIO (rewound to 0).""" + result = BytesIO() + await download_url_to_bytesio(url, result, timeout=timeout, cls=cls) + return result + + def _generate_operation_id(method: str, url: str, attempt: int) -> str: try: parsed = urlparse(url) From 426cde37f10dc391f9601ab938e02c0faa42db14 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 24 Oct 2025 16:56:51 -0700 Subject: [PATCH 30/49] Remove useless function (#10472) --- comfy/model_management.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 79d6ff9d441d..cf015a29ac5b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -998,12 +998,6 @@ def device_supports_non_blocking(device): return False return True -def device_should_use_non_blocking(device): - if not device_supports_non_blocking(device): - return False - return False - # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others - def force_channels_last(): if args.force_channels_last: return True From e86b79ab9ea7e740b80490353f3f5763840ede81 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 26 Oct 2025 00:35:30 +0300 Subject: [PATCH 31/49] convert Gemini API nodes to V3 schema (#10476) --- comfy_api_nodes/apinode_utils.py | 26 -- comfy_api_nodes/nodes_gemini.py | 607 +++++++++++----------------- comfy_api_nodes/util/__init__.py | 2 + comfy_api_nodes/util/conversions.py | 25 ++ 4 files changed, 271 insertions(+), 389 deletions(-) diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py index e3d2820592cd..4182c8f80b7f 100644 --- a/comfy_api_nodes/apinode_utils.py +++ b/comfy_api_nodes/apinode_utils.py @@ -3,8 +3,6 @@ import mimetypes from typing import Optional, Union from comfy.utils import common_upscale -from comfy_api.util import VideoContainer, VideoCodec -from comfy_api.input.video_types import VideoInput from comfy_api_nodes.apis.client import ( ApiClient, ApiEndpoint, @@ -209,30 +207,6 @@ async def upload_file_to_comfyapi( return response.download_url -def video_to_base64_string( - video: VideoInput, - container_format: VideoContainer = None, - codec: VideoCodec = None -) -> str: - """ - Converts a video input to a base64 string. - - Args: - video: The video input to convert - container_format: Optional container format to use (defaults to video.container if available) - codec: Optional codec to use (defaults to video.codec if available) - """ - video_bytes_io = BytesIO() - - # Use provided format/codec if specified, otherwise use video's own if available - format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4) - codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264) - - video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use) - video_bytes_io.seek(0) - return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") - - async def upload_images_to_comfyapi( image: torch.Tensor, max_images=8, diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index ca11b67ed192..67f2469ad36e 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -2,42 +2,47 @@ API Nodes for Gemini Multimodal LLM Usage via Remote API See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference """ + from __future__ import annotations +import base64 import json -import time import os +import time import uuid -import base64 -from io import BytesIO from enum import Enum -from typing import Optional, Literal +from io import BytesIO +from typing import Literal, Optional import torch +from typing_extensions import override import folder_paths -from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict -from server import PromptServer +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api.util import VideoCodec, VideoContainer from comfy_api_nodes.apis import ( GeminiContent, GeminiGenerateContentRequest, GeminiGenerateContentResponse, GeminiInlineData, - GeminiPart, GeminiMimeType, + GeminiPart, ) -from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest, GeminiImageConfig -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, +from comfy_api_nodes.apis.gemini_api import ( + GeminiImageConfig, + GeminiImageGenerateContentRequest, + GeminiImageGenerationConfig, ) -from comfy_api_nodes.apinode_utils import ( +from comfy_api_nodes.util import ( + ApiEndpoint, + audio_to_base64_string, + bytesio_to_image_tensor, + sync_op, + tensor_to_base64_string, + validate_string, video_to_base64_string, ) -from comfy_api_nodes.util import validate_string, tensor_to_base64_string, bytesio_to_image_tensor, audio_to_base64_string -from comfy_api.util import VideoContainer, VideoCodec - +from server import PromptServer GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB @@ -63,50 +68,6 @@ class GeminiImageModel(str, Enum): gemini_2_5_flash_image = "gemini-2.5-flash-image" -def get_gemini_endpoint( - model: GeminiModel, -) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]: - """ - Get the API endpoint for a given Gemini model. - - Args: - model: The Gemini model to use, either as enum or string value. - - Returns: - ApiEndpoint configured for the specific Gemini model. - """ - if isinstance(model, str): - model = GeminiModel(model) - return ApiEndpoint( - path=f"{GEMINI_BASE_ENDPOINT}/{model.value}", - method=HttpMethod.POST, - request_model=GeminiGenerateContentRequest, - response_model=GeminiGenerateContentResponse, - ) - - -def get_gemini_image_endpoint( - model: GeminiImageModel, -) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]: - """ - Get the API endpoint for a given Gemini model. - - Args: - model: The Gemini model to use, either as enum or string value. - - Returns: - ApiEndpoint configured for the specific Gemini model. - """ - if isinstance(model, str): - model = GeminiImageModel(model) - return ApiEndpoint( - path=f"{GEMINI_BASE_ENDPOINT}/{model.value}", - method=HttpMethod.POST, - request_model=GeminiImageGenerateContentRequest, - response_model=GeminiGenerateContentResponse, - ) - - def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]: """ Convert image tensor input to Gemini API compatible parts. @@ -119,9 +80,7 @@ def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]: """ image_parts: list[GeminiPart] = [] for image_index in range(image_input.shape[0]): - image_as_b64 = tensor_to_base64_string( - image_input[image_index].unsqueeze(0) - ) + image_as_b64 = tensor_to_base64_string(image_input[image_index].unsqueeze(0)) image_parts.append( GeminiPart( inlineData=GeminiInlineData( @@ -133,37 +92,7 @@ def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]: return image_parts -def create_text_part(text: str) -> GeminiPart: - """ - Create a text part for the Gemini API request. - - Args: - text: The text content to include in the request. - - Returns: - A GeminiPart object with the text content. - """ - return GeminiPart(text=text) - - -def get_parts_from_response( - response: GeminiGenerateContentResponse -) -> list[GeminiPart]: - """ - Extract all parts from the Gemini API response. - - Args: - response: The API response from Gemini. - - Returns: - List of response parts from the first candidate. - """ - return response.candidates[0].content.parts - - -def get_parts_by_type( - response: GeminiGenerateContentResponse, part_type: Literal["text"] | str -) -> list[GeminiPart]: +def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]: """ Filter response parts by their type. @@ -175,14 +104,10 @@ def get_parts_by_type( List of response parts matching the requested type. """ parts = [] - for part in get_parts_from_response(response): + for part in response.candidates[0].content.parts: if part_type == "text" and hasattr(part, "text") and part.text: parts.append(part) - elif ( - hasattr(part, "inlineData") - and part.inlineData - and part.inlineData.mimeType == part_type - ): + elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type: parts.append(part) # Skip parts that don't match the requested type return parts @@ -210,11 +135,11 @@ def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Te returned_image = bytesio_to_image_tensor(BytesIO(image_data)) image_tensors.append(returned_image) if len(image_tensors) == 0: - return torch.zeros((1,1024,1024,4)) + return torch.zeros((1, 1024, 1024, 4)) return torch.cat(image_tensors, dim=0) -class GeminiNode(ComfyNodeABC): +class GeminiNode(IO.ComfyNode): """ Node to generate text responses from a Gemini model. @@ -225,96 +150,79 @@ class GeminiNode(ComfyNodeABC): """ @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text inputs to the model, used to generate a response. You can include detailed instructions, questions, or context for the model.", - }, + def define_schema(cls): + return IO.Schema( + node_id="GeminiNode", + display_name="Google Gemini", + category="api node/text/Gemini", + description="Generate text responses with Google's Gemini AI model. " + "You can provide multiple types of inputs (text, images, audio, video) " + "as context for generating more relevant and meaningful responses.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text inputs to the model, used to generate a response. " + "You can include detailed instructions, questions, or context for the model.", ), - "model": ( - IO.COMBO, - { - "tooltip": "The Gemini model to use for generating responses.", - "options": [model.value for model in GeminiModel], - "default": GeminiModel.gemini_2_5_pro.value, - }, + IO.Combo.Input( + "model", + options=GeminiModel, + default=GeminiModel.gemini_2_5_pro, + tooltip="The Gemini model to use for generating responses.", ), - "seed": ( - IO.INT, - { - "default": 42, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.", - }, + IO.Int.Input( + "seed", + default=42, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="When seed is fixed to a specific value, the model makes a best effort to provide " + "the same response for repeated requests. Deterministic output isn't guaranteed. " + "Also, changing the model or parameter settings, such as the temperature, " + "can cause variations in the response even when you use the same seed value. " + "By default, a random seed value is used.", ), - }, - "optional": { - "images": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.", - }, + IO.Image.Input( + "images", + optional=True, + tooltip="Optional image(s) to use as context for the model. " + "To include multiple images, you can use the Batch Images node.", ), - "audio": ( - IO.AUDIO, - { - "tooltip": "Optional audio to use as context for the model.", - "default": None, - }, + IO.Audio.Input( + "audio", + optional=True, + tooltip="Optional audio to use as context for the model.", ), - "video": ( - IO.VIDEO, - { - "tooltip": "Optional video to use as context for the model.", - "default": None, - }, + IO.Video.Input( + "video", + optional=True, + tooltip="Optional video to use as context for the model.", ), - "files": ( - "GEMINI_INPUT_FILES", - { - "default": None, - "tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.", - }, + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Generate Content Input Files node.", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Generate text responses with Google's Gemini AI model. You can provide multiple types of inputs (text, images, audio, video) as context for generating more relevant and meaningful responses." - RETURN_TYPES = ("STRING",) - FUNCTION = "api_call" - CATEGORY = "api node/text/Gemini" - API_NODE = True - - def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]: - """ - Convert video input to Gemini API compatible parts. - - Args: - video_input: Video tensor from ComfyUI. - **kwargs: Additional arguments to pass to the conversion function. + ], + outputs=[ + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - Returns: - List of GeminiPart objects containing the encoded video. - """ + @classmethod + def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]: + """Convert video input to Gemini API compatible parts.""" - base_64_string = video_to_base64_string( - video_input, - container_format=VideoContainer.MP4, - codec=VideoCodec.H264 - ) + base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264) return [ GeminiPart( inlineData=GeminiInlineData( @@ -324,7 +232,8 @@ def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart ) ] - def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]: + @classmethod + def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]: """ Convert audio input to Gemini API compatible parts. @@ -337,10 +246,10 @@ def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]: audio_parts: list[GeminiPart] = [] for batch_index in range(audio_input["waveform"].shape[0]): # Recreate an IO.AUDIO object for the given batch dimension index - audio_at_index = { - "waveform": audio_input["waveform"][batch_index].unsqueeze(0), - "sample_rate": audio_input["sample_rate"], - } + audio_at_index = Input.Audio( + waveform=audio_input["waveform"][batch_index].unsqueeze(0), + sample_rate=audio_input["sample_rate"], + ) # Convert to MP3 format for compatibility with Gemini API audio_bytes = audio_to_base64_string( audio_at_index, @@ -357,38 +266,38 @@ def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]: ) return audio_parts - async def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, - model: GeminiModel, - images: Optional[IO.IMAGE] = None, - audio: Optional[IO.AUDIO] = None, - video: Optional[IO.VIDEO] = None, + model: str, + seed: int, + images: Optional[torch.Tensor] = None, + audio: Optional[Input.Audio] = None, + video: Optional[Input.Video] = None, files: Optional[list[GeminiPart]] = None, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[str]: - # Validate inputs + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) # Create parts list with text prompt as the first part - parts: list[GeminiPart] = [create_text_part(prompt)] + parts: list[GeminiPart] = [GeminiPart(text=prompt)] # Add other modal parts if images is not None: image_parts = create_image_parts(images) parts.extend(image_parts) if audio is not None: - parts.extend(self.create_audio_parts(audio)) + parts.extend(cls.create_audio_parts(audio)) if video is not None: - parts.extend(self.create_video_parts(video)) + parts.extend(cls.create_video_parts(video)) if files is not None: parts.extend(files) # Create response - response = await SynchronousOperation( - endpoint=get_gemini_endpoint(model), - request=GeminiGenerateContentRequest( + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + data=GeminiGenerateContentRequest( contents=[ GeminiContent( role="user", @@ -396,15 +305,15 @@ async def api_call( ) ] ), - auth_kwargs=kwargs, - ).execute() + response_model=GeminiGenerateContentResponse, + ) # Get result output output_text = get_text_from_response(response) - if unique_id and output_text: + if output_text: # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button. render_spec = { - "node_id": unique_id, + "node_id": cls.hidden.unique_id, "component": "ChatHistoryWidget", "props": { "history": json.dumps( @@ -424,10 +333,10 @@ async def api_call( render_spec, ) - return (output_text or "Empty response from Gemini model...",) + return IO.NodeOutput(output_text or "Empty response from Gemini model...") -class GeminiInputFiles(ComfyNodeABC): +class GeminiInputFiles(IO.ComfyNode): """ Loads and formats input files for use with the Gemini API. @@ -438,7 +347,7 @@ class GeminiInputFiles(ComfyNodeABC): """ @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: + def define_schema(cls): """ For details about the supported file input types, see: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference @@ -453,39 +362,37 @@ def INPUT_TYPES(cls) -> InputTypeDict: ] input_files = sorted(input_files, key=lambda x: x.name) input_files = [f.name for f in input_files] - return { - "required": { - "file": ( - IO.COMBO, - { - "tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.", - "options": input_files, - "default": input_files[0] if input_files else None, - }, + return IO.Schema( + node_id="GeminiInputFiles", + display_name="Gemini Input Files", + category="api node/text/Gemini", + description="Loads and prepares input files to include as inputs for Gemini LLM nodes. " + "The files will be read by the Gemini model when generating a response. " + "The contents of the text file count toward the token limit. " + "🛈 TIP: Can be chained together with other Gemini Input File nodes.", + inputs=[ + IO.Combo.Input( + "file", + options=input_files, + default=input_files[0] if input_files else None, + tooltip="Input files to include as context for the model. " + "Only accepts text (.txt) and PDF (.pdf) files for now.", ), - }, - "optional": { - "GEMINI_INPUT_FILES": ( + IO.Custom("GEMINI_INPUT_FILES").Input( "GEMINI_INPUT_FILES", - { - "tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.", - "default": None, - }, + optional=True, + tooltip="An optional additional file(s) to batch together with the file loaded from this node. " + "Allows chaining of input files so that a single message can include multiple input files.", ), - }, - } - - DESCRIPTION = "Loads and prepares input files to include as inputs for Gemini LLM nodes. The files will be read by the Gemini model when generating a response. The contents of the text file count toward the token limit. 🛈 TIP: Can be chained together with other Gemini Input File nodes." - RETURN_TYPES = ("GEMINI_INPUT_FILES",) - FUNCTION = "prepare_files" - CATEGORY = "api node/text/Gemini" - - def create_file_part(self, file_path: str) -> GeminiPart: - mime_type = ( - GeminiMimeType.application_pdf - if file_path.endswith(".pdf") - else GeminiMimeType.text_plain + ], + outputs=[ + IO.Custom("GEMINI_INPUT_FILES").Output(), + ], ) + + @classmethod + def create_file_part(cls, file_path: str) -> GeminiPart: + mime_type = GeminiMimeType.application_pdf if file_path.endswith(".pdf") else GeminiMimeType.text_plain # Use base64 string directly, not the data URI with open(file_path, "rb") as f: file_content = f.read() @@ -498,120 +405,95 @@ def create_file_part(self, file_path: str) -> GeminiPart: ) ) - def prepare_files( - self, file: str, GEMINI_INPUT_FILES: list[GeminiPart] = [] - ) -> tuple[list[GeminiPart]]: - """ - Loads and formats input files for Gemini API. - """ + @classmethod + def execute(cls, file: str, GEMINI_INPUT_FILES: Optional[list[GeminiPart]] = None) -> IO.NodeOutput: + """Loads and formats input files for Gemini API.""" + if GEMINI_INPUT_FILES is None: + GEMINI_INPUT_FILES = [] file_path = folder_paths.get_annotated_filepath(file) - input_file_content = self.create_file_part(file_path) - files = [input_file_content] + GEMINI_INPUT_FILES - return (files,) + input_file_content = cls.create_file_part(file_path) + return IO.NodeOutput([input_file_content] + GEMINI_INPUT_FILES) -class GeminiImage(ComfyNodeABC): - """ - Node to generate text and image responses from a Gemini model. +class GeminiImage(IO.ComfyNode): - This node allows users to interact with Google's Gemini AI models, providing - multimodal inputs (text, images, files) to generate coherent - text and image responses. The node works with the latest Gemini models, handling the - API communication and response parsing. - """ @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text prompt for generation", - }, + def define_schema(cls): + return IO.Schema( + node_id="GeminiImageNode", + display_name="Google Gemini Image", + category="api node/image/Gemini", + description="Edit images synchronously via Google API.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text prompt for generation", + default="", ), - "model": ( - IO.COMBO, - { - "tooltip": "The Gemini model to use for generating responses.", - "options": [model.value for model in GeminiImageModel], - "default": GeminiImageModel.gemini_2_5_flash_image.value, - }, + IO.Combo.Input( + "model", + options=GeminiImageModel, + default=GeminiImageModel.gemini_2_5_flash_image, + tooltip="The Gemini model to use for generating responses.", ), - "seed": ( - IO.INT, - { - "default": 42, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.", - }, + IO.Int.Input( + "seed", + default=42, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="When seed is fixed to a specific value, the model makes a best effort to provide " + "the same response for repeated requests. Deterministic output isn't guaranteed. " + "Also, changing the model or parameter settings, such as the temperature, " + "can cause variations in the response even when you use the same seed value. " + "By default, a random seed value is used.", ), - }, - "optional": { - "images": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.", - }, + IO.Image.Input( + "images", + optional=True, + tooltip="Optional image(s) to use as context for the model. " + "To include multiple images, you can use the Batch Images node.", ), - "files": ( - "GEMINI_INPUT_FILES", - { - "default": None, - "tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.", - }, + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Generate Content Input Files node.", ), - # TODO: later we can add this parameter later - # "n": ( - # IO.INT, - # { - # "default": 1, - # "min": 1, - # "max": 8, - # "step": 1, - # "display": "number", - # "tooltip": "How many images to generate", - # }, - # ), - "aspect_ratio": ( - IO.COMBO, - { - "tooltip": "Defaults to matching the output image size to that of your input image, or otherwise generates 1:1 squares.", - "options": ["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], - "default": "auto", - }, + IO.Combo.Input( + "aspect_ratio", + options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], + default="auto", + tooltip="Defaults to matching the output image size to that of your input image, " + "or otherwise generates 1:1 squares.", + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = (IO.IMAGE, IO.STRING) - FUNCTION = "api_call" - CATEGORY = "api node/image/Gemini" - DESCRIPTION = "Edit images synchronously via Google API." - API_NODE = True - - async def api_call( - self, + ], + outputs=[ + IO.Image.Output(), + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, prompt: str, - model: GeminiImageModel, - images: Optional[IO.IMAGE] = None, + model: str, + seed: int, + images: Optional[torch.Tensor] = None, files: Optional[list[GeminiPart]] = None, - n=1, aspect_ratio: str = "auto", - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) - parts: list[GeminiPart] = [create_text_part(prompt)] + parts: list[GeminiPart] = [GeminiPart(text=prompt)] if not aspect_ratio: aspect_ratio = "auto" # for backward compatability with old workflows; to-do remove this in December @@ -623,29 +505,27 @@ async def api_call( if files is not None: parts.extend(files) - response = await SynchronousOperation( - endpoint=get_gemini_image_endpoint(model), - request=GeminiImageGenerateContentRequest( + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + data=GeminiImageGenerateContentRequest( contents=[ - GeminiContent( - role="user", - parts=parts, - ), + GeminiContent(role="user", parts=parts), ], generationConfig=GeminiImageGenerationConfig( - responseModalities=["TEXT","IMAGE"], + responseModalities=["TEXT", "IMAGE"], imageConfig=None if aspect_ratio == "auto" else image_config, - ) + ), ), - auth_kwargs=kwargs, - ).execute() + response_model=GeminiGenerateContentResponse, + ) output_image = get_image_from_response(response) output_text = get_text_from_response(response) - if unique_id and output_text: + if output_text: # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button. render_spec = { - "node_id": unique_id, + "node_id": cls.hidden.unique_id, "component": "ChatHistoryWidget", "props": { "history": json.dumps( @@ -666,17 +546,18 @@ async def api_call( ) output_text = output_text or "Empty response from Gemini model..." - return (output_image, output_text,) + return IO.NodeOutput(output_image, output_text) + +class GeminiExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + GeminiNode, + GeminiImage, + GeminiInputFiles, + ] -NODE_CLASS_MAPPINGS = { - "GeminiNode": GeminiNode, - "GeminiImageNode": GeminiImage, - "GeminiInputFiles": GeminiInputFiles, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "GeminiNode": "Google Gemini", - "GeminiImageNode": "Google Gemini Image", - "GeminiInputFiles": "Gemini Input Files", -} +async def comfy_entrypoint() -> GeminiExtension: + return GeminiExtension() diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index ab96760cbce1..0cca2b59bb66 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -18,6 +18,7 @@ tensor_to_bytesio, tensor_to_pil, trim_video, + video_to_base64_string, ) from .download_helpers import ( download_url_as_bytesio, @@ -73,6 +74,7 @@ "tensor_to_bytesio", "tensor_to_pil", "trim_video", + "video_to_base64_string", # Validation utilities "get_number_of_images", "validate_aspect_ratio_closeness", diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 10cd1051b4d7..9f4c90c5cc3b 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -12,6 +12,7 @@ from comfy.utils import common_upscale from comfy_api.latest import Input, InputImpl +from comfy_api.util import VideoContainer, VideoCodec from ._helpers import mimetype_to_extension @@ -173,6 +174,30 @@ def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", co return base64.b64encode(audio_bytes).decode("utf-8") +def video_to_base64_string( + video: Input.Video, + container_format: VideoContainer = None, + codec: VideoCodec = None +) -> str: + """ + Converts a video input to a base64 string. + + Args: + video: The video input to convert + container_format: Optional container format to use (defaults to video.container if available) + codec: Optional codec to use (defaults to video.codec if available) + """ + video_bytes_io = BytesIO() + + # Use provided format/codec if specified, otherwise use video's own if available + format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4) + codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264) + + video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use) + video_bytes_io.seek(0) + return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") + + def audio_ndarray_to_bytesio( audio_data_np: np.ndarray, sample_rate: int, From 098a352f136c610071bcb74f13e5b0ca16e6e7b3 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 25 Oct 2025 17:05:22 -0700 Subject: [PATCH 32/49] Add warning for torch-directml usage (#10482) Added a warning message about the state of torch-directml. --- comfy/model_management.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index cf015a29ac5b..afe78f36ec0e 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -89,6 +89,7 @@ def get_supported_float8_types(): directml_enabled = False if args.directml is not None: + logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.") import torch_directml directml_enabled = True device_index = args.directml From f6bbc1ac846b7d9a73ae50c3a45cf5a41058c54d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 25 Oct 2025 20:07:29 -0700 Subject: [PATCH 33/49] Fix mistake. (#10484) --- comfy/sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sample.py b/comfy/sample.py index b1395da84ae0..2f8f3a51c5fc 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -17,7 +17,7 @@ def prepare_noise_inner(latent_image, generator, noise_inds=None): if i in unique_inds: noises.append(noise) noises = [noises[i] for i in inverse] - noises = torch.cat(noises, axis=0) + return torch.cat(noises, axis=0) def prepare_noise(latent_image, seed, noise_inds=None): """ From 9d529e53084bdec28f684f3886a26c93598e7338 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 26 Oct 2025 08:51:06 +0200 Subject: [PATCH 34/49] fix(api-nodes): random issues on Windows by capturing general OSError for retries (#10486) --- comfy_api_nodes/util/client.py | 15 +++++---------- comfy_api_nodes/util/download_helpers.py | 6 +++--- comfy_api_nodes/util/upload_helpers.py | 4 ++-- 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 5833b118fdf6..9c036d64b55d 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -2,7 +2,6 @@ import contextlib import json import logging -import socket import time import uuid from dataclasses import dataclass @@ -456,24 +455,20 @@ async def _diagnose_connectivity() -> dict[str, bool]: results = { "internet_accessible": False, "api_accessible": False, - "is_local_issue": False, - "is_api_issue": False, } timeout = aiohttp.ClientTimeout(total=5.0) async with aiohttp.ClientSession(timeout=timeout) as session: - try: + with contextlib.suppress(ClientError, OSError): async with session.get("https://www.google.com") as resp: results["internet_accessible"] = resp.status < 500 - except (ClientError, asyncio.TimeoutError, socket.gaierror): - results["is_local_issue"] = True + if not results["internet_accessible"]: return results parsed = urlparse(default_base_url()) health_url = f"{parsed.scheme}://{parsed.netloc}/health" - with contextlib.suppress(ClientError, asyncio.TimeoutError): + with contextlib.suppress(ClientError, OSError): async with session.get(health_url) as resp: results["api_accessible"] = resp.status < 500 - results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"] return results @@ -790,7 +785,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): except ProcessingInterrupted: logging.debug("Polling was interrupted by user") raise - except (ClientError, asyncio.TimeoutError, socket.gaierror) as e: + except (ClientError, OSError) as e: if attempt <= cfg.max_retries: logging.warning( "Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s", @@ -824,7 +819,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): delay *= cfg.retry_backoff continue diag = await _diagnose_connectivity() - if diag.get("is_local_issue"): + if not diag["internet_accessible"]: try: request_logger.log_request_response( operation_id=operation_id, diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 791dd5a5027b..f89045e12ad4 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -32,7 +32,7 @@ async def download_url_to_bytesio( dest: Optional[Union[BytesIO, IO[bytes], str, Path]], *, timeout: Optional[float] = None, - max_retries: int = 3, + max_retries: int = 5, retry_delay: float = 1.0, retry_backoff: float = 2.0, cls: type[COMFY_IO.ComfyNode] = None, @@ -177,7 +177,7 @@ async def _monitor(): return except asyncio.CancelledError: raise ProcessingInterrupted("Task cancelled") from None - except (ClientError, asyncio.TimeoutError) as e: + except (ClientError, OSError) as e: if attempt <= max_retries: with contextlib.suppress(Exception): request_logger.log_request_response( @@ -191,7 +191,7 @@ async def _monitor(): continue diag = await _diagnose_connectivity() - if diag.get("is_local_issue"): + if not diag["internet_accessible"]: raise LocalNetworkError( "Unable to connect to the network. Please check your internet connection and try again." ) from e diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index a345d451d4cd..7bfc61704a2a 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -290,7 +290,7 @@ async def _monitor(): return except asyncio.CancelledError: raise ProcessingInterrupted("Task cancelled") from None - except (aiohttp.ClientError, asyncio.TimeoutError) as e: + except (aiohttp.ClientError, OSError) as e: if attempt <= max_retries: with contextlib.suppress(Exception): request_logger.log_request_response( @@ -313,7 +313,7 @@ async def _monitor(): continue diag = await _diagnose_connectivity() - if diag.get("is_local_issue"): + if not diag["internet_accessible"]: raise LocalNetworkError( "Unable to connect to the network. Please check your internet connection and try again." ) from e From c170fd2db598a0bdce56f80e22e83e10ad731421 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 26 Oct 2025 17:23:01 -0700 Subject: [PATCH 35/49] Bump portable deps workflow to torch cu130 python 3.13.9 (#10493) --- .github/workflows/windows_release_dependencies.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml index f1e2946e66b0..f61ee21a230e 100644 --- a/.github/workflows/windows_release_dependencies.yml +++ b/.github/workflows/windows_release_dependencies.yml @@ -17,7 +17,7 @@ on: description: 'cuda version' required: true type: string - default: "129" + default: "130" python_minor: description: 'python minor version' @@ -29,7 +29,7 @@ on: description: 'python patch version' required: true type: string - default: "6" + default: "9" # push: # branches: # - master From efb35035f3531c6328a6eeff90fc88e873baa437 Mon Sep 17 00:00:00 2001 From: lspindler Date: Mon, 27 Oct 2025 07:55:44 +0100 Subject: [PATCH 36/49] Remove CK reference and ensure correct compute dtype --- comfy/model_detection.py | 12 ++---------- comfy/ops.py | 4 +--- comfy/quant_ops.py | 26 ++++---------------------- comfy/sd.py | 5 ++++- 4 files changed, 11 insertions(+), 36 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index ffb1885fd1b4..335ccbd17ee4 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -7,8 +7,7 @@ import torch -def detect_layer_quantization(state_dict, metadata, prefix="model.diffusion_model."): - # 1. Check for per-layer config in metadata +def detect_layer_quantization(metadata): quant_key = "_quantization_metadata" if metadata is not None and quant_key in metadata: quant_metadata = metadata.pop(quant_key) @@ -18,13 +17,6 @@ def detect_layer_quantization(state_dict, metadata, prefix="model.diffusion_mode return quant_metadata["layers"] else: raise ValueError(f"Invalid quantization metadata format") - - # 2. Check for legacy scaled_fp8 marker - scaled_fp8_key = f"{prefix}scaled_fp8" - if scaled_fp8_key in state_dict: - logging.debug("Detected legacy scaled_fp8 format, using legacy code path") - return None - return None @@ -724,7 +716,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal model_config.optimizations["fp8"] = True # Detect per-layer quantization (mixed precision) - layer_quant_config = detect_layer_quantization(state_dict, metadata, unet_key_prefix) + layer_quant_config = detect_layer_quantization(metadata) if layer_quant_config: model_config.layer_quant_config = layer_quant_config logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized") diff --git a/comfy/ops.py b/comfy/ops.py index 8d11aeefc0af..5edd4daa2bfb 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -561,11 +561,9 @@ def forward(self, input, *args, **kwargs): def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): - # If model_config.layer_quant_config exists, use new MixedPrecisionOps. if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config - # MixedPrecisionOps._compute_dtype = compute_dtype # TODO - MixedPrecisionOps._compute_dtype = torch.bfloat16 + MixedPrecisionOps._compute_dtype = compute_dtype logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") return MixedPrecisionOps diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 3802da8524e7..8d7f6480a31c 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -75,21 +75,6 @@ def _copy_layout_params(params): return new_params -try: - import comfy_kitchen as ck - ck.disable_backend("cutile") - _CK_AVAILABLE = True - logging.info("comfy_kitchen available for optimized quantization kernels") -except ImportError: - ck = None - _CK_AVAILABLE = False - logging.info("comfy_kitchen not available - using PyTorch fallbacks") -except Exception as e: - ck = None - _CK_AVAILABLE = False - logging.warning(f"comfy_kitchen import failed: {e} - using PyTorch fallbacks") - - class QuantizedLayout: """ Base class for quantization layouts. @@ -372,13 +357,10 @@ def quantize(cls, tensor, scale=None, fp8_dtype=torch.float8_e4m3fn): scale = torch.tensor(scale) scale = scale.to(device=tensor.device, dtype=torch.float32) - if _CK_AVAILABLE and tensor.device.type == "cuda": - qdata = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype) - else: - lp_amax = torch.finfo(fp8_dtype).max - tensor_scaled = tensor.float() / scale - torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) - qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format) + lp_amax = torch.finfo(fp8_dtype).max + tensor_scaled = tensor.float() / scale + torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) + qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format) layout_params = { 'scale': scale, diff --git a/comfy/sd.py b/comfy/sd.py index b965e98427d9..6411bb27d62e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1330,7 +1330,10 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): else: unet_dtype = dtype - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + if hasattr(model_config, "layer_quant_config"): + manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) + else: + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations) if model_options.get("fp8_optimizations", False): From a7216e18e5cf40d0dcbadd2f4e4c03a0c3f38f49 Mon Sep 17 00:00:00 2001 From: lspindler Date: Mon, 27 Oct 2025 08:41:23 +0100 Subject: [PATCH 37/49] Update unit tests --- comfy/ops.py | 2 +- comfy/quant_ops.py | 4 +- .../test_mixed_precision.py | 147 +++---- tests-unit/comfy_quant/test_quant_registry.py | 183 ++++++++ tests-unit/comfy_test/test_quant_detection.py | 262 ------------ tests-unit/comfy_test/test_quant_registry.py | 399 ------------------ 6 files changed, 235 insertions(+), 762 deletions(-) rename tests-unit/{comfy_test => comfy_quant}/test_mixed_precision.py (60%) create mode 100644 tests-unit/comfy_quant/test_quant_registry.py delete mode 100644 tests-unit/comfy_test/test_quant_detection.py delete mode 100644 tests-unit/comfy_test/test_quant_registry.py diff --git a/comfy/ops.py b/comfy/ops.py index 5edd4daa2bfb..8af1e949dcd2 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -546,7 +546,7 @@ def _forward(self, input, weight, bias): def forward_comfy_cast_weights(self, input): weight, bias = cast_bias_weight(self, input) - self._forward(input, weight, bias) + return self._forward(input, weight, bias) def forward(self, input, *args, **kwargs): run_every_op() diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 8d7f6480a31c..96d2fa03fdbd 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -143,7 +143,7 @@ def __tensor_flatten__(self): """ Tensor flattening protocol for proper device movement. """ - inner_tensors = ["_q_data"] + inner_tensors = ["_qdata"] ctx = { "layout_type": self._layout_type, } @@ -206,7 +206,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # Step 3: Fallback to dequantization if isinstance(args[0] if args else None, QuantizedTensor): - logging.warning(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") + logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") return cls._dequant_and_fallback(func, args, kwargs) @classmethod diff --git a/tests-unit/comfy_test/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py similarity index 60% rename from tests-unit/comfy_test/test_mixed_precision.py rename to tests-unit/comfy_quant/test_mixed_precision.py index cbfa2866da4d..e3455276063d 100644 --- a/tests-unit/comfy_test/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -1,8 +1,3 @@ -""" -End-to-end tests for mixed precision quantization. -Tests Phase 3: Mixed Precision Operations -""" - import unittest import torch import sys @@ -12,10 +7,10 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) from comfy import ops +from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout class SimpleModel(torch.nn.Module): - """Simple model for testing mixed precision""" def __init__(self, operations=ops.disable_weight_init): super().__init__() self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16) @@ -32,8 +27,7 @@ def forward(self, x): class TestMixedPrecisionOps(unittest.TestCase): - """Test MixedPrecisionOps end-to-end""" - + def test_all_layers_standard(self): """Test that model with no quantization works normally""" # Configure no quantization @@ -67,48 +61,54 @@ def test_mixed_precision_load(self): # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard layer_quant_config = { "layer1": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": False} # Disable for CPU testing + "format": "float8_e4m3fn", + "params": {} }, "layer3": { - "format": "fp8_e5m2_scaled", - "params": {"use_fp8_matmul": False} + "format": "float8_e4m3fn", + "params": {} } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create state dict with mixed precision fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) - fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e5m2) + fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn) state_dict = { # Layer 1: FP8 E4M3FN "layer1.weight": fp8_weight1, "layer1.bias": torch.randn(20, dtype=torch.bfloat16), - "layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), # Layer 2: Standard BF16 "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), "layer2.bias": torch.randn(30, dtype=torch.bfloat16), - # Layer 3: FP8 E5M2 + # Layer 3: FP8 E4M3FN "layer3.weight": fp8_weight3, "layer3.bias": torch.randn(40, dtype=torch.bfloat16), - "layer3.scale_weight": torch.tensor(1.5, dtype=torch.float32), + "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32), } - # Create model and load state dict + # Create model and load state dict (strict=False because custom loading pops keys) model = SimpleModel(operations=ops.MixedPrecisionOps) - model.load_state_dict(state_dict) + model.load_state_dict(state_dict, strict=False) + + # Verify weights are wrapped in QuantizedTensor + self.assertIsInstance(model.layer1.weight, QuantizedTensor) + self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout) - # Verify handlers are set up correctly - self.assertIsNotNone(model.layer1.quant_handler) - self.assertIsNone(model.layer2.quant_handler) # No quantization - self.assertIsNotNone(model.layer3.quant_handler) + # Layer 2 should NOT be quantized + self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) + + # Layer 3 should be quantized + self.assertIsInstance(model.layer3.weight, QuantizedTensor) + self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout) # Verify scales were loaded - self.assertEqual(model.layer1.scale_weight.item(), 2.0) - self.assertEqual(model.layer3.scale_weight.item(), 1.5) + self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) + self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5) # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) @@ -116,13 +116,13 @@ def test_mixed_precision_load(self): self.assertEqual(output.shape, (5, 40)) - def test_state_dict_round_trip(self): - """Test saving and loading state dict preserves quantization""" + def test_state_dict_quantized_preserved(self): + """Test that quantized weights are preserved in state_dict()""" # Configure mixed precision layer_quant_config = { "layer1": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": False} + "format": "float8_e4m3fn", + "params": {} } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config @@ -132,45 +132,35 @@ def test_state_dict_round_trip(self): state_dict1 = { "layer1.weight": fp8_weight, "layer1.bias": torch.randn(20, dtype=torch.bfloat16), - "layer1.scale_weight": torch.tensor(3.0, dtype=torch.float32), + "layer1.weight_scale": torch.tensor(3.0, dtype=torch.float32), "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), "layer2.bias": torch.randn(30, dtype=torch.bfloat16), "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model1 = SimpleModel(operations=ops.MixedPrecisionOps) - model1.load_state_dict(state_dict1) + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict1, strict=False) # Save state dict - state_dict2 = model1.state_dict() - - # Verify scale_weight is saved - self.assertIn("layer1.scale_weight", state_dict2) - self.assertEqual(state_dict2["layer1.scale_weight"].item(), 3.0) - - # Load into new model - model2 = SimpleModel(operations=ops.MixedPrecisionOps) - model2.load_state_dict(state_dict2) - - # Verify handler is set up - self.assertIsNotNone(model2.layer1.quant_handler) - self.assertEqual(model2.layer1.scale_weight.item(), 3.0) + state_dict2 = model.state_dict() - # Verify forward passes match - input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) - output1 = model1(input_tensor) - output2 = model2(input_tensor) + # Verify layer1.weight is a QuantizedTensor with scale preserved + self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) + self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) + self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout) - torch.testing.assert_close(output1, output2, rtol=1e-3, atol=1e-3) + # Verify non-quantized layers are standard tensors + self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) + self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor) def test_weight_function_compatibility(self): """Test that weight_function (LoRA) works with quantized layers""" # Configure FP8 quantization layer_quant_config = { "layer1": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": False} + "format": "float8_e4m3fn", + "params": {} } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config @@ -180,7 +170,7 @@ def test_weight_function_compatibility(self): state_dict = { "layer1.weight": fp8_weight, "layer1.bias": torch.randn(20, dtype=torch.bfloat16), - "layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), "layer2.bias": torch.randn(30, dtype=torch.bfloat16), "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), @@ -188,25 +178,24 @@ def test_weight_function_compatibility(self): } model = SimpleModel(operations=ops.MixedPrecisionOps) - model.load_state_dict(state_dict) + model.load_state_dict(state_dict, strict=False) # Add a weight function (simulating LoRA) - # LoRA delta must match weight shape (20, 10) + # This should trigger dequantization during forward pass def apply_lora(weight): - # Generate LoRA delta matching weight shape lora_delta = torch.randn_like(weight) * 0.01 return weight + lora_delta model.layer1.weight_function.append(apply_lora) - # Forward pass should work with LoRA + # Forward pass should work with LoRA (triggers weight_function path) input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) output = model(input_tensor) self.assertEqual(output.shape, (5, 40)) def test_error_handling_unknown_format(self): - """Test that unknown formats fall back gracefully""" + """Test that unknown formats raise error""" # Configure with unknown format layer_quant_config = { "layer1": { @@ -226,48 +215,10 @@ def test_error_handling_unknown_format(self): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - # Load should not crash, just log warning + # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS model = SimpleModel(operations=ops.MixedPrecisionOps) - model.load_state_dict(state_dict) - - # Handler should be None (fallback to standard) - self.assertIsNone(model.layer1.quant_handler) - - # Forward pass should still work - input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) - output = model(input_tensor) - self.assertEqual(output.shape, (5, 40)) - - -class TestPickOperationsWithMixedPrecision(unittest.TestCase): - """Test pick_operations with mixed precision config""" - - def test_pick_operations_with_layer_quant_config(self): - """Test that pick_operations returns MixedPrecisionOps when config present""" - from comfy import supported_models_base - - # Create model config with layer_quant_config - model_config = supported_models_base.BASE({}) - model_config.layer_quant_config = { - "layer1": {"format": "fp8_e4m3fn_scaled", "params": {}} - } - - result = ops.pick_operations(None, None, model_config=model_config) - - self.assertEqual(result, ops.MixedPrecisionOps) - self.assertEqual(ops.MixedPrecisionOps._layer_quant_config, model_config.layer_quant_config) - - def test_pick_operations_without_layer_quant_config(self): - """Test that pick_operations falls back to standard when no config""" - from comfy import supported_models_base - - model_config = supported_models_base.BASE({}) - model_config.layer_quant_config = None - - result = ops.pick_operations(None, None, model_config=model_config) - - self.assertEqual(result, ops.disable_weight_init) - + with self.assertRaises(KeyError): + model.load_state_dict(state_dict, strict=False) if __name__ == "__main__": unittest.main() diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py new file mode 100644 index 000000000000..263581417177 --- /dev/null +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -0,0 +1,183 @@ +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout + + +class TestQuantizedTensor(unittest.TestCase): + """Test the QuantizedTensor subclass with FP8 layout""" + + def test_creation(self): + """Test creating a QuantizedTensor with TensorCoreFP8Layout""" + fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(2.0) + layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} + + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.shape, (256, 128)) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt._layout_params['scale'], scale) + self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) + self.assertEqual(qt._layout_type, TensorCoreFP8Layout) + + def test_dequantize(self): + """Test explicit dequantization""" + + fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(3.0) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + dequantized = qt.dequantize() + + self.assertEqual(dequantized.dtype, torch.float32) + self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) + + def test_from_float(self): + """Test creating QuantizedTensor from float tensor""" + float_tensor = torch.randn(64, 32, dtype=torch.float32) + scale = torch.tensor(1.5) + + qt = QuantizedTensor.from_float( + float_tensor, + TensorCoreFP8Layout, + scale=scale, + fp8_dtype=torch.float8_e4m3fn + ) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt.shape, (64, 32)) + + # Verify dequantization gives approximately original values + dequantized = qt.dequantize() + mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean() + self.assertLess(mean_rel_error, 0.1) + + +class TestGenericUtilities(unittest.TestCase): + """Test generic utility operations""" + + def test_detach(self): + """Test detach operation on quantized tensor""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Detach should return a new QuantizedTensor + qt_detached = qt.detach() + + self.assertIsInstance(qt_detached, QuantizedTensor) + self.assertEqual(qt_detached.shape, qt.shape) + self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout) + + def test_clone(self): + """Test clone operation on quantized tensor""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Clone should return a new QuantizedTensor + qt_cloned = qt.clone() + + self.assertIsInstance(qt_cloned, QuantizedTensor) + self.assertEqual(qt_cloned.shape, qt.shape) + self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout) + + # Verify it's a deep copy + self.assertIsNot(qt_cloned._qdata, qt._qdata) + + def test_to_device(self): + """Test device transfer""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Moving to same device should work (CPU to CPU) + qt_cpu = qt.to('cpu') + + self.assertIsInstance(qt_cpu, QuantizedTensor) + self.assertEqual(qt_cpu.device.type, 'cpu') + self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') + + +class TestTensorCoreFP8Layout(unittest.TestCase): + """Test the TensorCoreFP8Layout implementation""" + + def test_quantize(self): + """Test quantization method""" + float_tensor = torch.randn(32, 64, dtype=torch.float32) + scale = torch.tensor(1.5) + + qdata, layout_params = TensorCoreFP8Layout.quantize( + float_tensor, + scale=scale, + fp8_dtype=torch.float8_e4m3fn + ) + + self.assertEqual(qdata.dtype, torch.float8_e4m3fn) + self.assertEqual(qdata.shape, float_tensor.shape) + self.assertIn('scale', layout_params) + self.assertIn('orig_dtype', layout_params) + self.assertEqual(layout_params['orig_dtype'], torch.float32) + + def test_dequantize(self): + """Test dequantization method""" + float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0 + scale = torch.tensor(1.0) + + qdata, layout_params = TensorCoreFP8Layout.quantize( + float_tensor, + scale=scale, + fp8_dtype=torch.float8_e4m3fn + ) + + dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) + + # Should approximately match original + self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) + + +class TestFallbackMechanism(unittest.TestCase): + """Test fallback for unsupported operations""" + + def test_unsupported_op_dequantizes(self): + """Test that unsupported operations fall back to dequantization""" + # Set seed for reproducibility + torch.manual_seed(42) + + # Create quantized tensor + a_fp32 = torch.randn(10, 20, dtype=torch.float32) + scale = torch.tensor(1.0) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + a_q = QuantizedTensor.from_float( + a_fp32, + TensorCoreFP8Layout, + scale=scale, + fp8_dtype=torch.float8_e4m3fn + ) + + # Call an operation that doesn't have a registered handler + # For example, torch.abs + result = torch.abs(a_q) + + # Should work via fallback (dequantize → abs → return) + self.assertNotIsInstance(result, QuantizedTensor) + expected = torch.abs(a_fp32) + # FP8 introduces quantization error, so use loose tolerance + mean_error = (result - expected).abs().mean() + self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests-unit/comfy_test/test_quant_detection.py b/tests-unit/comfy_test/test_quant_detection.py deleted file mode 100644 index bb952a81b3b2..000000000000 --- a/tests-unit/comfy_test/test_quant_detection.py +++ /dev/null @@ -1,262 +0,0 @@ -""" -Integration tests for quantization detection. -Tests Phase 2: Detection & Integration -""" - -import unittest -import torch -import sys -import os - -# Add comfy to path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) - -from comfy import model_detection - - -class TestNormalizeLayerName(unittest.TestCase): - """Test the normalize_layer_name helper function""" - - def test_strip_prefix_and_suffix(self): - """Test stripping prefix and suffix""" - known_prefixes = ["model.diffusion_model."] - result = model_detection.normalize_layer_name( - "model.diffusion_model.layer1.weight", - known_prefixes - ) - self.assertEqual(result, "layer1") - - def test_strip_multiple_prefixes(self): - """Test with multiple known prefixes""" - known_prefixes = ["model.diffusion_model.", "model.model.", "net."] - - result1 = model_detection.normalize_layer_name( - "model.diffusion_model.block.attn.weight", - known_prefixes - ) - self.assertEqual(result1, "block.attn") - - result2 = model_detection.normalize_layer_name( - "model.model.encoder.layer.weight", - known_prefixes - ) - self.assertEqual(result2, "encoder.layer") - - result3 = model_detection.normalize_layer_name( - "net.transformer.blocks.0.weight", - known_prefixes - ) - self.assertEqual(result3, "transformer.blocks.0") - - def test_strip_scale_weight_suffix(self): - """Test stripping scale_weight suffix""" - known_prefixes = ["model.diffusion_model."] - result = model_detection.normalize_layer_name( - "model.diffusion_model.layer1.scale_weight", - known_prefixes - ) - self.assertEqual(result, "layer1") - - def test_strip_bias_suffix(self): - """Test stripping bias suffix""" - known_prefixes = ["model.diffusion_model."] - result = model_detection.normalize_layer_name( - "model.diffusion_model.layer1.bias", - known_prefixes - ) - self.assertEqual(result, "layer1") - - def test_no_prefix_match(self): - """Test with no prefix match""" - known_prefixes = ["model.diffusion_model."] - result = model_detection.normalize_layer_name( - "other.model.layer1.weight", - known_prefixes - ) - # Should strip suffix but not prefix - self.assertEqual(result, "other.model.layer1") - - -class TestDetectLayerQuantization(unittest.TestCase): - """Test the detect_layer_quantization function""" - - def test_no_quantization(self): - """Test with no quantization markers""" - state_dict = { - "model.diffusion_model.layer1.weight": torch.randn(10, 20), - "model.diffusion_model.layer2.weight": torch.randn(20, 30), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - self.assertIsNone(result) - - def test_legacy_scaled_fp8(self): - """Test that legacy scaled_fp8 marker returns None""" - # Create FP8 tensor by converting from float32 - fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - state_dict = { - "model.diffusion_model.scaled_fp8": torch.tensor([], dtype=torch.float8_e4m3fn), - "model.diffusion_model.layer1.weight": fp8_weight, - "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - # Should return None to trigger legacy path - self.assertIsNone(result) - - def test_metadata_format(self): - """Test with new metadata format""" - metadata = { - "format_version": "1.0", - "layers": { - "layer1": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": True} - }, - "layer2": { - "format": "fp8_e5m2_scaled", - "params": {"use_fp8_matmul": True} - } - } - } - state_dict = { - "model.diffusion_model._quantization_metadata": metadata, - "model.diffusion_model.layer1.weight": torch.randn(10, 20), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - - self.assertIsNotNone(result) - self.assertIn("layer1", result) - self.assertIn("layer2", result) - self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled") - self.assertEqual(result["layer2"]["format"], "fp8_e5m2_scaled") - # Metadata should be popped from state_dict - self.assertNotIn("model.diffusion_model._quantization_metadata", state_dict) - - def test_mixed_precision_detection(self): - """Test detection of mixed precision via scale patterns""" - # Create FP8 tensors by converting from float32 - fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - fp8_weight3 = torch.randn(30, 40, dtype=torch.float32).to(torch.float8_e4m3fn) - state_dict = { - # Layer 1: FP8 (has scale_weight) - "model.diffusion_model.layer1.weight": fp8_weight1, - "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), - # Layer 2: Standard (no scale_weight) - "model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), - # Layer 3: FP8 (has scale_weight) - "model.diffusion_model.layer3.weight": fp8_weight3, - "model.diffusion_model.layer3.scale_weight": torch.tensor(1.0), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - - self.assertIsNotNone(result) - self.assertIn("layer1", result) - self.assertIn("layer3", result) - self.assertNotIn("layer2", result) # Layer 2 not quantized - self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled") - self.assertEqual(result["layer3"]["format"], "fp8_e4m3fn_scaled") - - def test_all_layers_quantized(self): - """Test that uniform quantization (all layers) returns None""" - # Create FP8 tensors by converting from float32 - fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - fp8_weight2 = torch.randn(20, 30, dtype=torch.float32).to(torch.float8_e4m3fn) - state_dict = { - # All layers have scale_weight - "model.diffusion_model.layer1.weight": fp8_weight1, - "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), - "model.diffusion_model.layer2.weight": fp8_weight2, - "model.diffusion_model.layer2.scale_weight": torch.tensor(1.0), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - - # If all layers are quantized, it's not mixed precision - # Should return None to use legacy scaled_fp8_ops path - self.assertIsNone(result) - - def test_fp8_e5m2_detection(self): - """Test detection of FP8 E5M2 format""" - # Create FP8 E5M2 tensor by converting from float32 - fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e5m2) - state_dict = { - "model.diffusion_model.layer1.weight": fp8_weight, - "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), - "model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - - self.assertIsNotNone(result) - self.assertIn("layer1", result) - self.assertEqual(result["layer1"]["format"], "fp8_e5m2_scaled") - - def test_invalid_metadata(self): - """Test with invalid metadata format""" - state_dict = { - "model.diffusion_model._quantization_metadata": "invalid_string", - "model.diffusion_model.layer1.weight": torch.randn(10, 20), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - # Should return None on invalid metadata - self.assertIsNone(result) - - def test_different_prefix(self): - """Test with different model prefix (audio model)""" - # Create FP8 tensor by converting from float32 - fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - state_dict = { - "model.model.layer1.weight": fp8_weight, - "model.model.layer1.scale_weight": torch.tensor(1.0), - "model.model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), - } - result = model_detection.detect_layer_quantization(state_dict, "model.model.") - - self.assertIsNotNone(result) - self.assertIn("layer1", result) - - -class TestPickOperationsIntegration(unittest.TestCase): - """Test pick_operations with model_config parameter""" - - def test_backward_compatibility(self): - """Test that pick_operations works without model_config (legacy)""" - from comfy import ops - - # Should work without model_config parameter - result = ops.pick_operations(None, None) - self.assertIsNotNone(result) - self.assertEqual(result, ops.disable_weight_init) - - def test_with_model_config_no_quant(self): - """Test with model_config but no quantization""" - from comfy import ops, supported_models_base - - model_config = supported_models_base.BASE({}) - model_config.layer_quant_config = None - - result = ops.pick_operations(None, None, model_config=model_config) - self.assertIsNotNone(result) - # Should use standard path - self.assertEqual(result, ops.disable_weight_init) - - def test_legacy_scaled_fp8(self): - """Test that legacy scaled_fp8 still works""" - from comfy import ops, supported_models_base - - model_config = supported_models_base.BASE({}) - model_config.scaled_fp8 = torch.float8_e4m3fn - - result = ops.pick_operations( - None, None, - scaled_fp8=torch.float8_e4m3fn, - model_config=model_config - ) - self.assertIsNotNone(result) - # Should return scaled_fp8_ops (the returned class is the inner class) - # Check that it's not the standard disable_weight_init - self.assertNotEqual(result, ops.disable_weight_init) - # Verify it has Linear class - self.assertTrue(hasattr(result, 'Linear')) - - -if __name__ == "__main__": - unittest.main() - diff --git a/tests-unit/comfy_test/test_quant_registry.py b/tests-unit/comfy_test/test_quant_registry.py deleted file mode 100644 index 5c624b1db9d8..000000000000 --- a/tests-unit/comfy_test/test_quant_registry.py +++ /dev/null @@ -1,399 +0,0 @@ -""" -Unit tests for tensor subclass quantization system. -Tests the new QuantizedTensorFP8 subclass and operation handlers. -""" - -import unittest -import torch -import sys -import os - -# Add comfy to path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) - -from comfy import ops -from comfy import quant_ops - - -class TestQuantizedTensorFP8(unittest.TestCase): - """Test the QuantizedTensorFP8 tensor subclass""" - - def test_creation(self): - """Test creating a QuantizedTensorFP8""" - fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(2.0) - - qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16) - - self.assertIsInstance(qt, quant_ops.QuantizedTensorFP8) - self.assertEqual(qt.shape, (256, 128)) - self.assertEqual(qt.dtype, torch.float8_e4m3fn) - self.assertEqual(qt._scale, scale) - self.assertEqual(qt._orig_dtype, torch.bfloat16) - - def test_dequantize(self): - """Test explicit dequantization""" - # Create a simple FP8 tensor - fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(3.0) - - qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.float32) - dequantized = qt.dequantize() - - # Dequantized should be approximately ones * 3.0 - self.assertEqual(dequantized.dtype, torch.float32) - self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) - - def test_repr(self): - """Test string representation""" - fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(2.5) - - qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16) - repr_str = repr(qt) - - self.assertIn("QuantizedTensorFP8", repr_str) - self.assertIn("shape", repr_str) - self.assertIn("scale", repr_str) - - -class TestOperationRegistry(unittest.TestCase): - """Test the operation registry system""" - - def test_registry_basics(self): - """Test that operations are registered""" - registered_ops = quant_ops.list_registered_ops() - - # Check that key operations are registered - self.assertIn(torch.ops.aten.linear.default, registered_ops) - self.assertIn(torch.ops.aten.silu.default, registered_ops) - self.assertIn(torch.ops.aten.layer_norm.default, registered_ops) - self.assertIn(torch.ops.aten.add.Tensor, registered_ops) - self.assertIn(torch.ops.aten.mul.Tensor, registered_ops) - - def test_get_handler(self): - """Test getting a registered handler""" - handler = quant_ops.get_quant_handler(torch.ops.aten.linear.default) - self.assertIsNotNone(handler) - self.assertTrue(callable(handler)) - - def test_custom_registration(self): - """Test registering a custom operation""" - - # Define a custom handler - @quant_ops.register_quant_op(torch.ops.aten.relu.default) - def custom_relu_handler(func, args, kwargs): - return func(*args, **kwargs) - - # Verify registration - handler = quant_ops.get_quant_handler(torch.ops.aten.relu.default) - self.assertIsNotNone(handler) - self.assertEqual(handler, custom_relu_handler) - - -class TestLinearHandler(unittest.TestCase): - """Test the linear operation handler""" - - def test_linear_with_quantized_weight(self): - """Test F.linear with quantized weight""" - # Set seed for reproducibility - torch.manual_seed(42) - - # Create quantized weight - weight_fp32 = torch.randn(256, 128, dtype=torch.float32) - scale = torch.tensor(2.0) - weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) - weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32) - - # Create input - input_tensor = torch.randn(16, 128, dtype=torch.float32) - - # Call linear (should trigger dispatch) - output = torch.nn.functional.linear(input_tensor, weight_q, bias=None) - - # Verify output shape - self.assertEqual(output.shape, (16, 256)) - - # Verify it's approximately correct (allowing for FP8 quantization error) - # Note: FP8 has limited precision, so use very loose tolerance - expected = torch.nn.functional.linear(input_tensor, weight_fp32, bias=None) - # Just check that it's in the right ballpark (within 50% error on average) - mean_rel_error = ((output - expected).abs() / (expected.abs() + 1e-6)).mean() - self.assertLess(mean_rel_error, 0.5, f"Mean relative error {mean_rel_error:.3f} is too large") - - def test_linear_with_bias(self): - """Test F.linear with quantized weight and bias""" - weight_fp32 = torch.randn(64, 32, dtype=torch.float32) - scale = torch.tensor(1.5) - weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) - weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32) - - input_tensor = torch.randn(8, 32, dtype=torch.float32) - bias = torch.randn(64, dtype=torch.float32) - - output = torch.nn.functional.linear(input_tensor, weight_q, bias) - - self.assertEqual(output.shape, (8, 64)) - - -class TestActivationHandlers(unittest.TestCase): - """Test activation function handlers""" - - def test_silu_with_quantized_input(self): - """Test SiLU with quantized input""" - # Create quantized input - input_fp32 = torch.randn(16, 128, dtype=torch.float32) - scale = torch.tensor(1.0) - input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn) - input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32) - - # Apply SiLU - output = torch.nn.functional.silu(input_q) - - # Should return a QuantizedTensorFP8 - self.assertIsInstance(output, quant_ops.QuantizedTensorFP8) - - # Verify approximate correctness - expected = torch.nn.functional.silu(input_fp32) - output_dq = output.dequantize() - self.assertTrue(torch.allclose(output_dq, expected, rtol=0.2, atol=0.2)) - - def test_layernorm_dequantizes(self): - """Test that LayerNorm dequantizes input""" - # Create quantized input - input_fp32 = torch.randn(16, 128, dtype=torch.float32) - scale = torch.tensor(1.0) - input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn) - input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32) - - # Apply LayerNorm - weight = torch.ones(128) - bias = torch.zeros(128) - output = torch.nn.functional.layer_norm(input_q, (128,), weight, bias) - - # Should NOT be quantized (LayerNorm breaks quantization) - self.assertNotIsInstance(output, quant_ops.QuantizedTensorFP8) - self.assertEqual(output.dtype, torch.float32) - - -class TestElementwiseHandlers(unittest.TestCase): - """Test element-wise operation handlers""" - - def test_add_mixed_tensors(self): - """Test addition with mixed quantized/non-quantized tensors""" - # Create quantized tensor - a_fp32 = torch.ones(10, 20, dtype=torch.float32) - scale = torch.tensor(1.0) - a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn) - a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32) - - # Non-quantized tensor - b = torch.ones(10, 20, dtype=torch.float32) * 2.0 - - # Add them - result = a_q + b - - # Should be dequantized - self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) - self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 3.0, rtol=0.1)) - - def test_mul_quantized_tensors(self): - """Test multiplication of two quantized tensors""" - a_fp32 = torch.ones(10, 20) * 2.0 - scale_a = torch.tensor(1.0) - a_fp8 = (a_fp32 / scale_a).to(torch.float8_e4m3fn) - a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale_a, orig_dtype=torch.float32) - - b_fp32 = torch.ones(10, 20) * 3.0 - scale_b = torch.tensor(1.0) - b_fp8 = (b_fp32 / scale_b).to(torch.float8_e4m3fn) - b_q = quant_ops.QuantizedTensorFP8(b_fp8, scale_b, orig_dtype=torch.float32) - - result = a_q * b_q - - # Should be dequantized - self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) - self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 6.0, rtol=0.1)) - - -class TestFallbackMechanism(unittest.TestCase): - """Test fallback for unsupported operations""" - - def test_unsupported_op_dequantizes(self): - """Test that unsupported operations fall back to dequantization""" - # Set seed for reproducibility - torch.manual_seed(42) - - # Create quantized tensor - a_fp32 = torch.randn(10, 20, dtype=torch.float32) - scale = torch.tensor(1.0) - a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn) - a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32) - - # Call an operation that doesn't have a registered handler - # For example, torch.abs - result = torch.abs(a_q) - - # Should work via fallback (dequantize → abs → return) - self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) - expected = torch.abs(a_fp32) - # FP8 introduces quantization error, so use loose tolerance - mean_error = (result - expected).abs().mean() - self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") - - -class TestMixedPrecisionOps(unittest.TestCase): - """Test MixedPrecisionOps integration""" - - def test_linear_layer_creation(self): - """Test that MixedPrecisionOps.Linear can be created""" - layer = ops.MixedPrecisionOps.Linear(128, 256, bias=True, device="cpu", dtype=torch.float32) - - self.assertIsInstance(layer, ops.MixedPrecisionOps.Linear) - self.assertFalse(layer._quantization_initialized) - self.assertIsNone(layer.quant_format) - - def test_layer_quant_config_detection(self): - """Test that layer quantization config is detected during load""" - # Set up layer config - ops.MixedPrecisionOps._layer_quant_config = { - "test_layer": { - "format": "fp8_e4m3fn", - "params": {} - } - } - - # Create a state dict with quantized weight - weight_fp32 = torch.randn(256, 128, dtype=torch.float32) - scale = torch.tensor(2.0) - weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) - - state_dict = { - "model.diffusion_model.test_layer.weight": weight_fp8, - "model.diffusion_model.test_layer.scale_weight": scale, - } - - # Create layer and load - layer = ops.MixedPrecisionOps.Linear(128, 256, bias=False, device="cpu", dtype=torch.float8_e4m3fn) - layer.weight = torch.nn.Parameter(torch.zeros(256, 128, dtype=torch.float8_e4m3fn)) - - # Manually call _load_from_state_dict - layer._load_from_state_dict( - state_dict, - prefix="model.diffusion_model.test_layer.", - local_metadata={}, - strict=True, - missing_keys=[], - unexpected_keys=[], - error_msgs=[] - ) - - # Verify quantization was initialized - self.assertTrue(layer._quantization_initialized) - self.assertEqual(layer.quant_format, "fp8_e4m3fn") - self.assertIsNotNone(layer.quant_scale) - - # Verify weight is wrapped - self.assertIsInstance(layer.weight.data, quant_ops.QuantizedTensorFP8) - - # Clean up - ops.MixedPrecisionOps._layer_quant_config = {} - - -class TestBackwardCompatibility(unittest.TestCase): - """Test backward compatibility with legacy systems""" - - def test_legacy_ops_classes_exist(self): - """Test that legacy ops classes still exist""" - self.assertTrue(hasattr(ops, 'disable_weight_init')) - self.assertTrue(hasattr(ops, 'manual_cast')) - self.assertTrue(hasattr(ops, 'fp8_ops')) - self.assertTrue(hasattr(ops, 'scaled_fp8_ops')) - - def test_pick_operations_legacy_path(self): - """Test pick_operations returns correct class for legacy cases""" - # Test standard case - result = ops.pick_operations(torch.float32, torch.float32) - self.assertEqual(result, ops.disable_weight_init) - - # Test manual cast case - result = ops.pick_operations(torch.float32, torch.float16) - self.assertEqual(result, ops.manual_cast) - - -class TestFP8LinearUnification(unittest.TestCase): - """Test that fp8_linear now uses the unified tensor subclass infrastructure""" - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA required for FP8") - def test_fp8_linear_uses_tensor_subclass(self): - """Verify fp8_linear wraps tensors in QuantizedTensorFP8""" - torch.manual_seed(42) - - # Create a mock Linear layer with FP8 weight - linear = ops.fp8_ops.Linear(4, 3, bias=True) - linear.weight = torch.nn.Parameter( - torch.randn(3, 4, dtype=torch.bfloat16).to(torch.float8_e4m3fn), - requires_grad=False - ) - linear.bias = torch.nn.Parameter( - torch.randn(3, dtype=torch.bfloat16), - requires_grad=False - ) - linear.scale_weight = torch.tensor(1.0) - linear.scale_input = None # No input scaling - - # Create input - input_tensor = torch.randn(2, 4, dtype=torch.bfloat16) - - # Call fp8_linear - should work without errors - try: - result = ops.fp8_linear(linear, input_tensor) - self.assertIsNotNone(result) - self.assertEqual(result.shape, (2, 3)) - except Exception as e: - # On CPU or unsupported hardware, _scaled_mm might not be available - # but the function should still complete without syntax errors - pass - - def test_fp8_linear_maintains_signature(self): - """Verify fp8_linear maintains its original function signature""" - import inspect - sig = inspect.signature(ops.fp8_linear) - params = list(sig.parameters.keys()) - - # Should have 'self' and 'input' parameters - self.assertIn('self', params) - self.assertIn('input', params) - self.assertEqual(len(params), 2) - - def test_fp8_linear_returns_none_for_non_fp8(self): - """Verify fp8_linear returns None for non-FP8 weights""" - # Create a Linear layer with BF16 weight (not FP8) - linear = ops.disable_weight_init.Linear(4, 3, bias=False) - linear.weight = torch.nn.Parameter( - torch.randn(3, 4, dtype=torch.bfloat16), - requires_grad=False - ) - - input_tensor = torch.randn(2, 4, dtype=torch.bfloat16) - - # Should return None for non-FP8 weights - result = ops.fp8_linear(linear, input_tensor) - self.assertIsNone(result) - - def test_fp8_ops_linear_uses_fp8_linear(self): - """Verify fp8_ops.Linear still uses fp8_linear in forward pass""" - linear = ops.fp8_ops.Linear(4, 3, bias=False) - - # Verify the class has the forward_comfy_cast_weights method - self.assertTrue(hasattr(linear, 'forward_comfy_cast_weights')) - - # The forward_comfy_cast_weights should attempt to call fp8_linear - # (we can't easily test this without mocking, but we verify structure) - import inspect - source = inspect.getsource(linear.forward_comfy_cast_weights) - self.assertIn('fp8_linear', source) - - -if __name__ == "__main__": - unittest.main() From 2a8b8264426c311eebee0ec9eb167f20f678c952 Mon Sep 17 00:00:00 2001 From: lspindler Date: Mon, 27 Oct 2025 08:52:50 +0100 Subject: [PATCH 38/49] ruff lint --- comfy/model_base.py | 2 +- comfy/model_detection.py | 2 +- comfy/ops.py | 20 ++--- comfy/quant_ops.py | 76 +++++++++---------- .../comfy_quant/test_mixed_precision.py | 68 ++++++++--------- tests-unit/comfy_quant/test_quant_registry.py | 67 ++++++++-------- 6 files changed, 117 insertions(+), 118 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 7b4651f8eb1b..f850cc402049 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -326,7 +326,7 @@ def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_ if self.model_config.scaled_fp8 is not None: unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8) - + # Save mixed precision metadata if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config: metadata = { diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 335ccbd17ee4..c4fc27742a3a 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -16,7 +16,7 @@ def detect_layer_quantization(metadata): logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})") return quant_metadata["layers"] else: - raise ValueError(f"Invalid quantization metadata format") + raise ValueError("Invalid quantization metadata format") return None diff --git a/comfy/ops.py b/comfy/ops.py index 8af1e949dcd2..e2d76d7a97bf 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -325,7 +325,7 @@ class Embedding(disable_weight_init.Embedding): def fp8_linear(self, input): """ - Legacy FP8 linear function for backward compatibility. + Legacy FP8 linear function for backward compatibility. Uses QuantizedTensor subclass for dispatch. """ dtype = self.weight.dtype @@ -339,7 +339,7 @@ def fp8_linear(self, input): input_shape = input.shape input_dtype = input.dtype - + if len(input.shape) == 3: w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) @@ -354,14 +354,14 @@ def fp8_linear(self, input): scale_input = torch.ones((), device=input.device, dtype=torch.float32) else: scale_input = scale_input.to(input.device) - + # Wrap weight in QuantizedTensor - this enables unified dispatch # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, fp8_dtype=dtype) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) - + if tensor_2d: return o.reshape(input_shape[0], -1) return o.reshape((-1, input_shape[1], self.weight.shape[0])) @@ -503,8 +503,8 @@ def __init__( def reset_parameters(self): return None - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): device = self.factory_kwargs["device"] @@ -520,10 +520,10 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) if quant_format is None: raise ValueError(f"Unknown quantization format for layer {layer_name}") - + mixin = QUANT_FORMAT_MIXINS[quant_format] self.layout_type = mixin["layout_type"] - + layout_params = { 'scale': state_dict.pop(f"{prefix}weight_scale", None), 'orig_dtype': MixedPrecisionOps._compute_dtype @@ -558,7 +558,7 @@ def forward(self, input, *args, **kwargs): not isinstance(input, QuantizedTensor)): input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype) return self._forward(input, self.weight, self.bias) - + def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: @@ -566,7 +566,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ MixedPrecisionOps._compute_dtype = compute_dtype logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") return MixedPrecisionOps - + fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 96d2fa03fdbd..aa1a231bd15c 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -31,7 +31,7 @@ def register_generic_util(torch_op): Decorator to register a generic utility that works for all layouts. Args: torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default) - + Example: @register_generic_util(torch.ops.aten.detach.default) def generic_detach(func, args, kwargs): @@ -78,10 +78,10 @@ def _copy_layout_params(params): class QuantizedLayout: """ Base class for quantization layouts. - + A layout encapsulates the format-specific logic for quantization/dequantization and provides a uniform interface for extracting raw tensors needed for computation. - + New quantization formats should subclass this and implement the required methods. """ @classmethod @@ -90,8 +90,8 @@ def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]: @staticmethod def dequantize(qdata, **layout_params) -> torch.Tensor: - raise NotImplementedError(f"TensorLayout must implement dequantize()") - + raise NotImplementedError("TensorLayout must implement dequantize()") + @classmethod def get_plain_tensors(cls, qtensor) -> torch.Tensor: raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()") @@ -100,45 +100,45 @@ def get_plain_tensors(cls, qtensor) -> torch.Tensor: class QuantizedTensor(torch.Tensor): """ Universal quantized tensor that works with any layout. - + This tensor subclass uses a pluggable layout system to support multiple quantization formats (FP8, INT4, INT8, etc.) without code duplication. - + The layout_type determines format-specific behavior, while common operations (detach, clone, to) are handled generically. - + Attributes: _qdata: The quantized tensor data _layout_type: Layout class (e.g., TensorCoreFP8Layout) _layout_params: Dict with layout-specific params (scale, zero_point, etc.) """ - + @staticmethod def __new__(cls, qdata, layout_type, layout_params): """ Create a quantized tensor. - + Args: qdata: The quantized data tensor layout_type: Layout class (subclass of QuantizedLayout) layout_params: Dict with layout-specific parameters """ return torch.Tensor._make_subclass(cls, qdata, require_grad=False) - + def __init__(self, qdata, layout_type, layout_params): self._qdata = qdata.contiguous() self._layout_type = layout_type self._layout_params = layout_params - + def __repr__(self): layout_name = self._layout_type.__name__ param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" - + @property def layout_type(self): return self._layout_type - + def __tensor_flatten__(self): """ Tensor flattening protocol for proper device movement. @@ -147,7 +147,7 @@ def __tensor_flatten__(self): ctx = { "layout_type": self._layout_type, } - + tensor_params = {} non_tensor_params = {} for k, v in self._layout_params.items(): @@ -155,17 +155,17 @@ def __tensor_flatten__(self): tensor_params[k] = v else: non_tensor_params[k] = v - + ctx["tensor_param_keys"] = list(tensor_params.keys()) ctx["non_tensor_params"] = non_tensor_params - + for k, v in tensor_params.items(): attr_name = f"_layout_param_{k}" object.__setattr__(self, attr_name, v) inner_tensors.append(attr_name) - + return inner_tensors, ctx - + @staticmethod def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): """ @@ -174,41 +174,41 @@ def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): """ layout_type = ctx["layout_type"] layout_params = dict(ctx["non_tensor_params"]) - + for key in ctx["tensor_param_keys"]: attr_name = f"_layout_param_{key}" layout_params[key] = inner_tensors[attr_name] - + return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params) - + @classmethod def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs) return cls(qdata, layout_type, layout_params) - + def dequantize(self) -> torch.Tensor: return self._layout_type.dequantize(self._qdata, **self._layout_params) - + @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): kwargs = kwargs or {} - + # Step 1: Check generic utilities first (detach, clone, to, etc.) if func in _GENERIC_UTILS: return _GENERIC_UTILS[func](func, args, kwargs) - + # Step 2: Check layout-specific handlers (linear, matmul, etc.) layout_type = _get_layout_from_args(args) if layout_type and func in _LAYOUT_REGISTRY: handler = _LAYOUT_REGISTRY[func].get(layout_type) if handler: return handler(func, args, kwargs) - + # Step 3: Fallback to dequantization if isinstance(args[0] if args else None, QuantizedTensor): logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") return cls._dequant_and_fallback(func, args, kwargs) - + @classmethod def _dequant_and_fallback(cls, func, args, kwargs): def dequant_arg(arg): @@ -217,7 +217,7 @@ def dequant_arg(arg): elif isinstance(arg, (list, tuple)): return type(arg)(dequant_arg(a) for a in arg) return arg - + new_args = dequant_arg(args) new_kwargs = dequant_arg(kwargs) return func(*new_args, **new_kwargs) @@ -239,13 +239,13 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout= f"QuantizedTensor: dtype conversion requested to {target_dtype}, " f"but not supported for quantized tensors. Ignoring dtype." ) - + if target_layout is not None and target_layout != torch.strided: logging.warning( f"QuantizedTensor: layout change requested to {target_layout}, " f"but not supported. Ignoring layout." ) - + # Handle device transfer current_device = qt._qdata.device if target_device is not None: @@ -254,7 +254,7 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout= target_device = torch.device(target_device) if isinstance(current_device, str): current_device = torch.device(current_device) - + if target_device != current_device: logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") new_q_data = qt._qdata.to(device=target_device) @@ -262,7 +262,7 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout= new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") return new_qt - + logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original") return qt @@ -318,7 +318,7 @@ def generic_to_dtype_layout(func, args, kwargs): def generic_copy_(func, args, kwargs): qt_dest = args[0] src = args[1] - + if isinstance(qt_dest, QuantizedTensor): if isinstance(src, QuantizedTensor): # Copy from another quantized tensor @@ -383,15 +383,15 @@ def fp8_linear(func, args, kwargs): input_tensor = args[0] weight = args[1] bias = args[2] if len(args) > 2 else None - + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) - + out_dtype = kwargs.get("out_dtype") if out_dtype is None: out_dtype = input_tensor._layout_params['orig_dtype'] - + weight_t = plain_weight.t() tensor_2d = False @@ -424,7 +424,7 @@ def fp8_linear(func, args, kwargs): return QuantizedTensor(output, TensorCoreFP8Layout, output_params) else: return output - + except Exception as e: raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index e3455276063d..1102f9bd4b28 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -16,7 +16,7 @@ def __init__(self, operations=ops.disable_weight_init): self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16) self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16) self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16) - + def forward(self, x): x = self.layer1(x) x = torch.nn.functional.relu(x) @@ -32,10 +32,10 @@ def test_all_layers_standard(self): """Test that model with no quantization works normally""" # Configure no quantization ops.MixedPrecisionOps._layer_quant_config = {} - + # Create model model = SimpleModel(operations=ops.MixedPrecisionOps) - + # Initialize weights manually model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16)) model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16)) @@ -43,19 +43,19 @@ def test_all_layers_standard(self): model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16)) model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16)) model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16)) - + # Initialize weight_function and bias_function for layer in [model.layer1, model.layer2, model.layer3]: layer.weight_function = [] layer.bias_function = [] - + # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) output = model(input_tensor) - + self.assertEqual(output.shape, (5, 40)) self.assertEqual(output.dtype, torch.bfloat16) - + def test_mixed_precision_load(self): """Test loading a mixed precision model from state dict""" # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard @@ -70,52 +70,52 @@ def test_mixed_precision_load(self): } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config - + # Create state dict with mixed precision fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn) - + state_dict = { # Layer 1: FP8 E4M3FN "layer1.weight": fp8_weight1, "layer1.bias": torch.randn(20, dtype=torch.bfloat16), "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), - + # Layer 2: Standard BF16 "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), "layer2.bias": torch.randn(30, dtype=torch.bfloat16), - + # Layer 3: FP8 E4M3FN "layer3.weight": fp8_weight3, "layer3.bias": torch.randn(40, dtype=torch.bfloat16), "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32), } - + # Create model and load state dict (strict=False because custom loading pops keys) model = SimpleModel(operations=ops.MixedPrecisionOps) model.load_state_dict(state_dict, strict=False) - + # Verify weights are wrapped in QuantizedTensor self.assertIsInstance(model.layer1.weight, QuantizedTensor) self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout) - + # Layer 2 should NOT be quantized self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) - + # Layer 3 should be quantized self.assertIsInstance(model.layer3.weight, QuantizedTensor) self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout) - + # Verify scales were loaded self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5) - + # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) output = model(input_tensor) - + self.assertEqual(output.shape, (5, 40)) - + def test_state_dict_quantized_preserved(self): """Test that quantized weights are preserved in state_dict()""" # Configure mixed precision @@ -126,7 +126,7 @@ def test_state_dict_quantized_preserved(self): } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config - + # Create and load model fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) state_dict1 = { @@ -138,22 +138,22 @@ def test_state_dict_quantized_preserved(self): "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - + model = SimpleModel(operations=ops.MixedPrecisionOps) model.load_state_dict(state_dict1, strict=False) - + # Save state dict state_dict2 = model.state_dict() - + # Verify layer1.weight is a QuantizedTensor with scale preserved self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout) - + # Verify non-quantized layers are standard tensors self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor) - + def test_weight_function_compatibility(self): """Test that weight_function (LoRA) works with quantized layers""" # Configure FP8 quantization @@ -164,7 +164,7 @@ def test_weight_function_compatibility(self): } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config - + # Create and load model fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) state_dict = { @@ -176,24 +176,24 @@ def test_weight_function_compatibility(self): "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - + model = SimpleModel(operations=ops.MixedPrecisionOps) model.load_state_dict(state_dict, strict=False) - + # Add a weight function (simulating LoRA) # This should trigger dequantization during forward pass def apply_lora(weight): lora_delta = torch.randn_like(weight) * 0.01 return weight + lora_delta - + model.layer1.weight_function.append(apply_lora) - + # Forward pass should work with LoRA (triggers weight_function path) input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) output = model(input_tensor) - + self.assertEqual(output.shape, (5, 40)) - + def test_error_handling_unknown_format(self): """Test that unknown formats raise error""" # Configure with unknown format @@ -204,7 +204,7 @@ def test_error_handling_unknown_format(self): } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config - + # Create state dict state_dict = { "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16), @@ -214,7 +214,7 @@ def test_error_handling_unknown_format(self): "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - + # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS model = SimpleModel(operations=ops.MixedPrecisionOps) with self.assertRaises(KeyError): diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py index 263581417177..26e91a7ee7d0 100644 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -11,51 +11,51 @@ class TestQuantizedTensor(unittest.TestCase): """Test the QuantizedTensor subclass with FP8 layout""" - + def test_creation(self): """Test creating a QuantizedTensor with TensorCoreFP8Layout""" fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(2.0) layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} - + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) - + self.assertIsInstance(qt, QuantizedTensor) self.assertEqual(qt.shape, (256, 128)) self.assertEqual(qt.dtype, torch.float8_e4m3fn) self.assertEqual(qt._layout_params['scale'], scale) self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) self.assertEqual(qt._layout_type, TensorCoreFP8Layout) - + def test_dequantize(self): """Test explicit dequantization""" fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(3.0) layout_params = {'scale': scale, 'orig_dtype': torch.float32} - + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) dequantized = qt.dequantize() - + self.assertEqual(dequantized.dtype, torch.float32) self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) - + def test_from_float(self): """Test creating QuantizedTensor from float tensor""" float_tensor = torch.randn(64, 32, dtype=torch.float32) scale = torch.tensor(1.5) - + qt = QuantizedTensor.from_float( - float_tensor, - TensorCoreFP8Layout, + float_tensor, + TensorCoreFP8Layout, scale=scale, fp8_dtype=torch.float8_e4m3fn ) - + self.assertIsInstance(qt, QuantizedTensor) self.assertEqual(qt.dtype, torch.float8_e4m3fn) self.assertEqual(qt.shape, (64, 32)) - + # Verify dequantization gives approximately original values dequantized = qt.dequantize() mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean() @@ -64,48 +64,48 @@ def test_from_float(self): class TestGenericUtilities(unittest.TestCase): """Test generic utility operations""" - + def test_detach(self): """Test detach operation on quantized tensor""" fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) - + # Detach should return a new QuantizedTensor qt_detached = qt.detach() - + self.assertIsInstance(qt_detached, QuantizedTensor) self.assertEqual(qt_detached.shape, qt.shape) self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout) - + def test_clone(self): """Test clone operation on quantized tensor""" fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) - + # Clone should return a new QuantizedTensor qt_cloned = qt.clone() - + self.assertIsInstance(qt_cloned, QuantizedTensor) self.assertEqual(qt_cloned.shape, qt.shape) self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout) - + # Verify it's a deep copy self.assertIsNot(qt_cloned._qdata, qt._qdata) - + def test_to_device(self): """Test device transfer""" fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) - + # Moving to same device should work (CPU to CPU) qt_cpu = qt.to('cpu') - + self.assertIsInstance(qt_cpu, QuantizedTensor) self.assertEqual(qt_cpu.device.type, 'cpu') self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') @@ -113,64 +113,63 @@ def test_to_device(self): class TestTensorCoreFP8Layout(unittest.TestCase): """Test the TensorCoreFP8Layout implementation""" - + def test_quantize(self): """Test quantization method""" float_tensor = torch.randn(32, 64, dtype=torch.float32) scale = torch.tensor(1.5) - + qdata, layout_params = TensorCoreFP8Layout.quantize( float_tensor, scale=scale, fp8_dtype=torch.float8_e4m3fn ) - + self.assertEqual(qdata.dtype, torch.float8_e4m3fn) self.assertEqual(qdata.shape, float_tensor.shape) self.assertIn('scale', layout_params) self.assertIn('orig_dtype', layout_params) self.assertEqual(layout_params['orig_dtype'], torch.float32) - + def test_dequantize(self): """Test dequantization method""" float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0 scale = torch.tensor(1.0) - + qdata, layout_params = TensorCoreFP8Layout.quantize( float_tensor, scale=scale, fp8_dtype=torch.float8_e4m3fn ) - + dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) - + # Should approximately match original self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) class TestFallbackMechanism(unittest.TestCase): """Test fallback for unsupported operations""" - + def test_unsupported_op_dequantizes(self): """Test that unsupported operations fall back to dequantization""" # Set seed for reproducibility torch.manual_seed(42) - + # Create quantized tensor a_fp32 = torch.randn(10, 20, dtype=torch.float32) scale = torch.tensor(1.0) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} a_q = QuantizedTensor.from_float( a_fp32, TensorCoreFP8Layout, scale=scale, fp8_dtype=torch.float8_e4m3fn ) - + # Call an operation that doesn't have a registered handler # For example, torch.abs result = torch.abs(a_q) - + # Should work via fallback (dequantize → abs → return) self.assertNotIsInstance(result, QuantizedTensor) expected = torch.abs(a_fp32) From 70acf793465c57a4780cc8ffdabb2d2231ebf58e Mon Sep 17 00:00:00 2001 From: lspindler Date: Thu, 16 Oct 2025 16:07:43 +0200 Subject: [PATCH 39/49] Implement mixed precision operations with a registry design and metadate for quant spec in checkpoint. --- comfy/model_base.py | 10 +- comfy/model_detection.py | 125 +++++++++ comfy/ops.py | 484 ++++++++++++++++++++++++++++++++- comfy/supported_models_base.py | 1 + 4 files changed, 618 insertions(+), 2 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index e877f19ac6c2..e0589ba92095 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -134,7 +134,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: fp8 = model_config.optimizations.get("fp8", False) - operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8) + operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config) else: operations = model_config.custom_operations self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) @@ -332,6 +332,14 @@ def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_ if self.model_config.scaled_fp8 is not None: unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8) + + # Save mixed precision metadata + if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config: + metadata = { + "format_version": "1.0", + "layers": self.model_config.layer_quant_config + } + unet_state_dict["_quantization_metadata"] = metadata unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 141f1e164834..c7ef7dab6e02 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -6,6 +6,125 @@ import logging import torch + +# ============================================================================== +# Quantization Detection Functions +# ============================================================================== + +def normalize_layer_name(full_key, known_prefixes): + """ + Strip model prefix and parameter suffix from a state dict key. + + Args: + full_key: Full state dict key (e.g., "model.diffusion_model.layer1.weight") + known_prefixes: List of known model prefixes to strip + + Returns: + Normalized layer name (e.g., "layer1") + """ + name = full_key + + # Strip model prefix + for prefix in known_prefixes: + if name.startswith(prefix): + name = name[len(prefix):] + break + + # Remove parameter suffix + for suffix in [".weight", ".bias", ".scale_weight", ".scale_input"]: + if name.endswith(suffix): + name = name[:-len(suffix)] + break + + return name + + +def detect_layer_quantization(state_dict, prefix="model.diffusion_model."): + """ + Detect per-layer quantization configuration from state dict. + + Detection priority: + 1. Check for _quantization_metadata key (new format) + 2. Check for scaled_fp8 key (legacy format - return None) + 3. Check for per-layer scale_weight patterns (mixed detection) + 4. No quantization detected (return None) + + Args: + state_dict: Model state dictionary + prefix: Key prefix for model layers + + Returns: + Dict mapping layer names to quantization configs, or None for legacy/no quantization. + + Example return value: + { + "input_blocks.5.1.transformer_blocks.0.attn1.to_q": { + "format": "fp8_e4m3fn_scaled", + "params": {"use_fp8_matmul": True} + }, + "middle_block.1.transformer_blocks.0.attn2.to_k": { + "format": "fp8_e5m2_scaled", + "params": {"use_fp8_matmul": True} + } + } + """ + + # 1. Check for new metadata format + metadata_key = f"{prefix}_quantization_metadata" + if metadata_key in state_dict: + try: + metadata = state_dict.pop(metadata_key) + if isinstance(metadata, dict) and "layers" in metadata: + logging.info(f"Found quantization metadata (version {metadata.get('format_version', 'unknown')})") + return metadata["layers"] + else: + logging.warning(f"Invalid quantization metadata format, ignoring") + except Exception as e: + logging.error(f"Failed to parse quantization metadata: {e}") + return None + + # 2. Check for legacy scaled_fp8 marker + # If present, return None to use legacy code path + scaled_fp8_key = f"{prefix}scaled_fp8" + if scaled_fp8_key in state_dict: + logging.debug("Detected legacy scaled_fp8 format, using legacy code path") + return None + + # 3. Check for per-layer scale patterns (mixed precision without metadata) + # Look for layers that have scale_weight but not all layers have it + known_prefixes = [prefix] + layer_configs = {} + layers_with_scale = set() + layers_with_weight = set() + + for key in state_dict.keys(): + if key.startswith(prefix): + if key.endswith(".scale_weight"): + layer_name = normalize_layer_name(key, known_prefixes) + layers_with_scale.add(layer_name) + # Detect format based on weight dtype + weight_key = f"{prefix}{layer_name}.weight" + if weight_key in state_dict: + weight_dtype = state_dict[weight_key].dtype + if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + format_name = "fp8_e4m3fn_scaled" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2_scaled" + layer_configs[layer_name] = { + "format": format_name, + "params": {"use_fp8_matmul": True} + } + elif key.endswith(".weight") and not key.endswith(".scale_weight"): + layer_name = normalize_layer_name(key, known_prefixes) + layers_with_weight.add(layer_name) + + # If we found scale_weight on some but not all layers, it's mixed precision + if layer_configs and len(layers_with_scale) < len(layers_with_weight): + logging.info(f"Detected mixed precision via scale patterns: {len(layers_with_scale)} quantized layers, {len(layers_with_weight)} total layers") + return layer_configs + + # 4. No quantization detected + return None + + def count_blocks(state_dict_keys, prefix_string): count = 0 while True: @@ -701,6 +820,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal else: model_config.optimizations["fp8"] = True + # Detect per-layer quantization (mixed precision) + layer_quant_config = detect_layer_quantization(state_dict, unet_key_prefix) + if layer_quant_config: + model_config.layer_quant_config = layer_quant_config + logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized") + return model_config def unet_prefix_from_state_dict(state_dict): diff --git a/comfy/ops.py b/comfy/ops.py index 934e21261edc..fac2be7282da 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -478,7 +478,457 @@ def forward_comfy_cast_weights(self, input): def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) -def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None): + +# ============================================================================== +# Quantization Format Registry System +# ============================================================================== + +class QuantFormatHandler: + """ + Base class for all quantization format handlers. + + A handler encapsulates the logic for a specific quantization format + (e.g., FP8 scaled, MX formats) and manages the quantization + parameters and forward pass for quantized layers. + """ + + def __init__(self, layer, **config): + """ + Initialize handler for a specific layer. + + Args: + layer: The nn.Module layer (Linear, Conv2d, etc.) + **config: Format-specific configuration + """ + self.layer = layer + self.config = config + + def setup_parameters(self): + """ + Initialize quantization parameters on the layer. + Called during layer construction or load_state_dict. + + Subclasses should create parameters like scale_weight, scale_input, etc. + and attach them to self.layer. + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement setup_parameters()") + + def forward(self, *args, **kwargs): + """ + Execute quantized forward pass. + + Signature matches the layer's expected forward pass. + Handler accesses layer parameters via self.layer (weight, bias, etc.) + + Args: + *args: Positional arguments matching layer forward signature + **kwargs: Keyword arguments matching layer forward signature + + Returns: + Layer output tensor + + Examples: + Linear: forward(input) + Conv2d: forward(input) + GroupNorm: forward(input) + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement forward()") + + def load_state_dict(self, state_dict, prefix): + """ + Load quantization parameters from state dict. + + Args: + state_dict: State dictionary + prefix: Key prefix for this layer (e.g., "model.diffusion_model.layer1.") + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement load_state_dict()") + + def state_dict(self, prefix): + """ + Save quantization parameters to state dict. + + Args: + prefix: Key prefix for this layer + + Returns: + Dictionary of quantization parameters with full keys + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement state_dict()") + + def convert_weight(self, weight, inplace=False): + """ + Convert weight from quantized to full precision (dequantize). + + Args: + weight: Quantized weight tensor + inplace: Whether to modify in-place + + Returns: + Dequantized weight tensor + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement convert_weight()") + + def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False): + """ + Convert and set weight from full precision to quantized. + + Args: + weight: Full precision weight tensor + inplace_update: Whether to update layer weight in-place + seed: Random seed for stochastic rounding + return_weight: If True, return quantized weight without setting + + Returns: + Quantized weight if return_weight=True, else None + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement set_weight()") + + +class QuantFormatRegistry: + """ + Global registry for quantization formats. + + Formats are registered with a unique name and handler class. + Custom formats can be registered by custom nodes. + """ + + _formats = {} + + @classmethod + def register(cls, name, handler_class, **default_config): + """ + Register a new quantization format. + + Args: + name: Unique format identifier (e.g., "fp8_e4m3fn_scaled") + handler_class: Handler class implementing QuantFormatHandler + **default_config: Default configuration parameters + + Example: + QuantFormatRegistry.register( + "fp8_e4m3fn_scaled", + handler_class=FP8ScaledHandler, + base_dtype=torch.float8_e4m3fn, + quantize_activation=False, + use_fp8_matmul=True, + ) + """ + if not issubclass(handler_class, QuantFormatHandler): + raise TypeError(f"handler_class must be a subclass of QuantFormatHandler, got {handler_class}") + + cls._formats[name] = { + "handler": handler_class, + "config": default_config.copy() + } + logging.debug(f"Registered quantization format: {name}") + + @classmethod + def get(cls, name, **override_config): + """ + Get format info with optional config overrides. + + Args: + name: Format identifier + **override_config: Configuration overrides + + Returns: + Dict with 'handler' (class) and 'config' (dict) keys + + Raises: + ValueError: If format name not registered + """ + if name not in cls._formats: + available = ", ".join(cls._formats.keys()) if cls._formats else "none" + raise ValueError(f"Unknown quantization format: '{name}'. Available formats: {available}") + + format_info = cls._formats[name].copy() + # Merge override_config into default config + config = format_info["config"].copy() + config.update(override_config) + format_info["config"] = config + return format_info + + @classmethod + def list_formats(cls): + """List all registered format names""" + return list(cls._formats.keys()) + + @classmethod + def is_registered(cls, name): + """Check if a format is registered""" + return name in cls._formats + + +class FP8ScaledHandler(QuantFormatHandler): + """ + Handler for FP8 quantization with per-tensor scaling. + + Supports both weight-only and weight+activation quantization. + Compatible with existing fp8_linear implementation. + """ + + def setup_parameters(self): + """Initialize scale_weight and optionally scale_input""" + device = self.layer.weight.device + dtype = torch.float32 + + # Always have scale_weight for FP8 + if not hasattr(self.layer, 'scale_weight') or self.layer.scale_weight is None: + self.layer.scale_weight = torch.nn.Parameter( + torch.ones((), device=device, dtype=dtype), + requires_grad=False + ) + + # scale_input is optional (for activation quantization) + if self.config.get("quantize_activation", False): + if not hasattr(self.layer, 'scale_input') or self.layer.scale_input is None: + self.layer.scale_input = torch.nn.Parameter( + torch.ones((), device=device, dtype=dtype), + requires_grad=False + ) + else: + self.layer.scale_input = None + + def forward(self, *args, **kwargs): + """ + FP8 forward pass with optional activation quantization. + Supports Linear layers (Conv2d in future). + """ + # Detect layer type and dispatch + if isinstance(self.layer, torch.nn.Linear): + return self._forward_linear(*args, **kwargs) + else: + raise NotImplementedError( + f"FP8ScaledHandler not implemented for {type(self.layer).__name__}" + ) + + def _forward_linear(self, input): + """FP8 forward for Linear layers""" + # Try fast path with fp8_linear if enabled + if self.config.get("use_fp8_matmul", False) and not self.layer.training: + try: + result = fp8_linear(self.layer, input) + if result is not None: + return result + except Exception as e: + logging.debug(f"FP8 matmul failed, falling back to standard path: {e}") + + # Standard path: dequantize and compute + weight, bias = cast_bias_weight(self.layer, input) + + # Dequantize weight + scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) + + # Apply weight functions (LoRA, etc.) - they see dequantized weights + if hasattr(self.layer, 'weight_function') and len(self.layer.weight_function) > 0: + weight = weight * scale + for f in self.layer.weight_function: + weight = f(weight) + else: + weight = weight * scale + + if hasattr(self.layer, 'bias_function') and len(self.layer.bias_function) > 0: + for f in self.layer.bias_function: + bias = f(bias) if bias is not None else None + + # Execute linear operation + # Optimization: multiply by scale on smaller tensor + if weight.numel() < input.numel() and len(self.layer.weight_function) == 0: + return torch.nn.functional.linear(input, weight, bias) + else: + return torch.nn.functional.linear(input, weight, bias) + + def load_state_dict(self, state_dict, prefix): + """Load scale parameters from state dict""" + scale_weight_key = f"{prefix}scale_weight" + if scale_weight_key in state_dict: + self.layer.scale_weight.data.copy_(state_dict[scale_weight_key]) + + scale_input_key = f"{prefix}scale_input" + if scale_input_key in state_dict and self.layer.scale_input is not None: + self.layer.scale_input.data.copy_(state_dict[scale_input_key]) + + def state_dict(self, prefix): + """Save scale parameters to state dict""" + result = {f"{prefix}scale_weight": self.layer.scale_weight} + if self.layer.scale_input is not None: + result[f"{prefix}scale_input"] = self.layer.scale_input + return result + + def convert_weight(self, weight, inplace=False): + """Dequantize: multiply by scale""" + scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) + if inplace: + weight *= scale + return weight + return weight * scale + + def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False): + """Quantize: divide by scale with stochastic rounding""" + scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) + quantized = comfy.float.stochastic_rounding( + weight / scale, + self.layer.weight.dtype, + seed=seed + ) + + if return_weight: + return quantized + + if inplace_update: + self.layer.weight.data.copy_(quantized) + else: + self.layer.weight = torch.nn.Parameter(quantized, requires_grad=False) + + +# ============================================================================== +# Mixed Precision Operations +# ============================================================================== + +class MixedPrecisionOps(disable_weight_init): + """ + Operations class supporting per-layer quantization (mixed precision). + + This class enables different layers to use different quantization formats + within the same model (e.g., some layers FP8, others BF16). + + Layer-specific quantization is configured via _layer_quant_config class variable, + which is set by pick_operations() when a model has mixed precision. + """ + + _layer_quant_config = {} # Class variable set by pick_operations() + + class Linear(disable_weight_init.Linear): + """Linear layer with optional per-layer quantization""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.quant_handler = None + self._handler_initialized = False + + def reset_parameters(self): + # Don't allocate weights - return None like disable_weight_init + return None + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + """ + Called by PyTorch during load_state_dict. + This is where we initialize the handler since we now know the layer name. + """ + if not self._handler_initialized: + # Normalize layer name from prefix + layer_name = prefix.rstrip('.') + + # Strip known model prefixes + for model_prefix in ["model.diffusion_model.", "model.model.", "net."]: + if layer_name.startswith(model_prefix): + layer_name = layer_name[len(model_prefix):] + break + + # Check if this layer has quantization config + # Access via parent class since _layer_quant_config is a class variable + if layer_name in MixedPrecisionOps._layer_quant_config: + config = MixedPrecisionOps._layer_quant_config[layer_name] + try: + format_info = QuantFormatRegistry.get( + config["format"], + **config.get("params", {}) + ) + + # Initialize handler + self.quant_handler = format_info["handler"](self, **format_info["config"]) + self.quant_handler.setup_parameters() + + # Let handler load its parameters (scale_weight, etc.) + self.quant_handler.load_state_dict(state_dict, prefix) + + logging.debug(f"Initialized {config['format']} handler for layer {layer_name}") + except ValueError as e: + # Format not registered - fall back to standard precision + logging.warning( + f"Quantization format '{config['format']}' not registered for layer {layer_name}. " + f"Falling back to standard precision. Error: {e}" + ) + self.quant_handler = None + except Exception as e: + logging.error(f"Failed to initialize quantization handler for {layer_name}: {e}") + self.quant_handler = None + + self._handler_initialized = True + + # Call parent to load weight and bias + super()._load_from_state_dict( + state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs + ) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + """Save layer parameters including quantization metadata""" + super()._save_to_state_dict(destination, prefix, keep_vars) + + # Save handler parameters (scale_weight, etc.) + if self.quant_handler: + handler_dict = self.quant_handler.state_dict(prefix) + destination.update(handler_dict) + + def forward_comfy_cast_weights(self, input): + """Forward pass with optional quantization""" + if self.quant_handler: + # Use handler for quantized forward + return self.quant_handler.forward(input) + else: + # Standard path for non-quantized layers + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.linear(input, weight, bias) + + def forward(self, *args, **kwargs): + """Main forward pass""" + run_every_op() + # Same logic as disable_weight_init.Linear + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + + @classmethod + def conv_nd(s, dims, *args, **kwargs): + """Create Conv layer (same as disable_weight_init)""" + if dims == 2: + return s.Conv2d(*args, **kwargs) + elif dims == 3: + return s.Conv3d(*args, **kwargs) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): + """ + Select appropriate operations class for model. + + NEW: If model_config.layer_quant_config exists, returns MixedPrecisionOps (Phase 3). + LEGACY: All other paths unchanged for backward compatibility. + + Args: + weight_dtype: Weight storage dtype + compute_dtype: Computation dtype + load_device: Device for loading + disable_fast_fp8: Disable fast FP8 paths + fp8_optimizations: Enable FP8 optimizations + scaled_fp8: Legacy FP8 dtype marker + model_config: Model config object (optional, for mixed precision support) + + Returns: + Operations class (e.g., MixedPrecisionOps, fp8_ops, disable_weight_init) + """ + # NEW: Check for mixed precision + if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: + MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config + logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") + return MixedPrecisionOps + + # LEGACY paths (unchanged) fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) @@ -503,3 +953,35 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ return disable_weight_init return manual_cast + + +# ============================================================================== +# Register built-in quantization formats +# ============================================================================== + +# FP8 E4M3FN weight-only quantization +QuantFormatRegistry.register( + "fp8_e4m3fn_scaled", + handler_class=FP8ScaledHandler, + base_dtype=torch.float8_e4m3fn, + quantize_activation=False, + use_fp8_matmul=True, +) + +# FP8 E4M3FN weight+activation quantization +QuantFormatRegistry.register( + "fp8_e4m3fn_scaled_dynamic", + handler_class=FP8ScaledHandler, + base_dtype=torch.float8_e4m3fn, + quantize_activation=True, + use_fp8_matmul=True, +) + +# FP8 E5M2 weight-only quantization +QuantFormatRegistry.register( + "fp8_e5m2_scaled", + handler_class=FP8ScaledHandler, + base_dtype=torch.float8_e5m2, + quantize_activation=False, + use_fp8_matmul=True, +) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 54573abb110d..e4bd7451429b 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -50,6 +50,7 @@ class BASE: manual_cast_dtype = None custom_operations = None scaled_fp8 = None + layer_quant_config = None # Per-layer quantization configuration for mixed precision optimizations = {"fp8": False} @classmethod From 388294677ee8a97a6fe4fd65e7de436da62f45f3 Mon Sep 17 00:00:00 2001 From: lspindler Date: Wed, 22 Oct 2025 10:30:00 +0200 Subject: [PATCH 40/49] Updated design using Tensor Subclasses --- comfy/model_detection.py | 4 +- comfy/ops.py | 514 ++++-------------- comfy/quant_ops.py | 346 ++++++++++++ tests-unit/comfy_test/test_mixed_precision.py | 274 ++++++++++ tests-unit/comfy_test/test_quant_detection.py | 262 +++++++++ tests-unit/comfy_test/test_quant_registry.py | 399 ++++++++++++++ 6 files changed, 1400 insertions(+), 399 deletions(-) create mode 100644 comfy/quant_ops.py create mode 100644 tests-unit/comfy_test/test_mixed_precision.py create mode 100644 tests-unit/comfy_test/test_quant_detection.py create mode 100644 tests-unit/comfy_test/test_quant_registry.py diff --git a/comfy/model_detection.py b/comfy/model_detection.py index c7ef7dab6e02..7a3851228b3a 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -107,10 +107,10 @@ def detect_layer_quantization(state_dict, prefix="model.diffusion_model."): if weight_key in state_dict: weight_dtype = state_dict[weight_key].dtype if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - format_name = "fp8_e4m3fn_scaled" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2_scaled" + format_name = "fp8_e4m3fn" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2" layer_configs[layer_name] = { "format": format_name, - "params": {"use_fp8_matmul": True} + "params": {} } elif key.endswith(".weight") and not key.endswith(".scale_weight"): layer_name = normalize_layer_name(key, known_prefixes) diff --git a/comfy/ops.py b/comfy/ops.py index fac2be7282da..6afbc2cff0e6 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -344,6 +344,21 @@ class Embedding(disable_weight_init.Embedding): def fp8_linear(self, input): + """ + Legacy FP8 linear function - now uses tensor subclass infrastructure. + + This function maintains backward compatibility with existing code while + routing all FP8 computation through the unified tensor subclass system. + All actual FP8 matmul logic is handled by the registered operation handlers + in quant_ops.py via __torch_dispatch__. + + Args: + self: Linear layer with FP8 weight and scale parameters + input: Input tensor (any dtype) + + Returns: + Output tensor or None if weight is not FP8 + """ dtype = self.weight.dtype if dtype not in [torch.float8_e4m3fn]: return None @@ -355,10 +370,12 @@ def fp8_linear(self, input): input_shape = input.shape input_dtype = input.dtype + if len(input.shape) == 3: + # Get weight and bias using standard casting w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) - w = w.t() + # Get scales (same as before) scale_weight = self.scale_weight scale_input = self.scale_input if scale_weight is None: @@ -368,23 +385,31 @@ def fp8_linear(self, input): if scale_input is None: scale_input = torch.ones((), device=input.device, dtype=torch.float32) - input = torch.clamp(input, min=-448, max=448, out=input) - input = input.reshape(-1, input_shape[2]).to(dtype).contiguous() else: scale_input = scale_input.to(input.device) - input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous() - - if bias is not None: - o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) + + # Wrap weight in QuantizedTensorFP8 - this enables unified dispatch + quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype) + + # Handle input quantization and wrapping + if self.scale_input is None: + # Clamp input to FP8 range and quantize + input = torch.clamp(input, min=-448, max=448, out=input) + input_fp8 = input.reshape(-1, input_shape[2]).to(dtype).contiguous() else: - o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight) - - if isinstance(o, tuple): - o = o[0] - + # Apply inverse scale and quantize + input_fp8 = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous() + + # Wrap input in QuantizedTensorFP8 + quantized_input = QuantizedTensorFP8(input_fp8, scale_input, orig_dtype=input_dtype) + + # Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py! + # This is the key unification: all FP8 computation goes through one path + o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) + + # Reshape output if tensor_2d: return o.reshape(input_shape[0], -1) - return o.reshape((-1, input_shape[1], self.weight.shape[0])) return None @@ -479,307 +504,8 @@ def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) -# ============================================================================== -# Quantization Format Registry System -# ============================================================================== - -class QuantFormatHandler: - """ - Base class for all quantization format handlers. - - A handler encapsulates the logic for a specific quantization format - (e.g., FP8 scaled, MX formats) and manages the quantization - parameters and forward pass for quantized layers. - """ - - def __init__(self, layer, **config): - """ - Initialize handler for a specific layer. - - Args: - layer: The nn.Module layer (Linear, Conv2d, etc.) - **config: Format-specific configuration - """ - self.layer = layer - self.config = config - - def setup_parameters(self): - """ - Initialize quantization parameters on the layer. - Called during layer construction or load_state_dict. - - Subclasses should create parameters like scale_weight, scale_input, etc. - and attach them to self.layer. - """ - raise NotImplementedError(f"{self.__class__.__name__} must implement setup_parameters()") - - def forward(self, *args, **kwargs): - """ - Execute quantized forward pass. - - Signature matches the layer's expected forward pass. - Handler accesses layer parameters via self.layer (weight, bias, etc.) - - Args: - *args: Positional arguments matching layer forward signature - **kwargs: Keyword arguments matching layer forward signature - - Returns: - Layer output tensor - - Examples: - Linear: forward(input) - Conv2d: forward(input) - GroupNorm: forward(input) - """ - raise NotImplementedError(f"{self.__class__.__name__} must implement forward()") - - def load_state_dict(self, state_dict, prefix): - """ - Load quantization parameters from state dict. - - Args: - state_dict: State dictionary - prefix: Key prefix for this layer (e.g., "model.diffusion_model.layer1.") - """ - raise NotImplementedError(f"{self.__class__.__name__} must implement load_state_dict()") - - def state_dict(self, prefix): - """ - Save quantization parameters to state dict. - - Args: - prefix: Key prefix for this layer - - Returns: - Dictionary of quantization parameters with full keys - """ - raise NotImplementedError(f"{self.__class__.__name__} must implement state_dict()") - - def convert_weight(self, weight, inplace=False): - """ - Convert weight from quantized to full precision (dequantize). - - Args: - weight: Quantized weight tensor - inplace: Whether to modify in-place - - Returns: - Dequantized weight tensor - """ - raise NotImplementedError(f"{self.__class__.__name__} must implement convert_weight()") - - def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False): - """ - Convert and set weight from full precision to quantized. - - Args: - weight: Full precision weight tensor - inplace_update: Whether to update layer weight in-place - seed: Random seed for stochastic rounding - return_weight: If True, return quantized weight without setting - - Returns: - Quantized weight if return_weight=True, else None - """ - raise NotImplementedError(f"{self.__class__.__name__} must implement set_weight()") - - -class QuantFormatRegistry: - """ - Global registry for quantization formats. - - Formats are registered with a unique name and handler class. - Custom formats can be registered by custom nodes. - """ - - _formats = {} - - @classmethod - def register(cls, name, handler_class, **default_config): - """ - Register a new quantization format. - - Args: - name: Unique format identifier (e.g., "fp8_e4m3fn_scaled") - handler_class: Handler class implementing QuantFormatHandler - **default_config: Default configuration parameters - - Example: - QuantFormatRegistry.register( - "fp8_e4m3fn_scaled", - handler_class=FP8ScaledHandler, - base_dtype=torch.float8_e4m3fn, - quantize_activation=False, - use_fp8_matmul=True, - ) - """ - if not issubclass(handler_class, QuantFormatHandler): - raise TypeError(f"handler_class must be a subclass of QuantFormatHandler, got {handler_class}") - - cls._formats[name] = { - "handler": handler_class, - "config": default_config.copy() - } - logging.debug(f"Registered quantization format: {name}") - - @classmethod - def get(cls, name, **override_config): - """ - Get format info with optional config overrides. - - Args: - name: Format identifier - **override_config: Configuration overrides - - Returns: - Dict with 'handler' (class) and 'config' (dict) keys - - Raises: - ValueError: If format name not registered - """ - if name not in cls._formats: - available = ", ".join(cls._formats.keys()) if cls._formats else "none" - raise ValueError(f"Unknown quantization format: '{name}'. Available formats: {available}") - - format_info = cls._formats[name].copy() - # Merge override_config into default config - config = format_info["config"].copy() - config.update(override_config) - format_info["config"] = config - return format_info - - @classmethod - def list_formats(cls): - """List all registered format names""" - return list(cls._formats.keys()) - - @classmethod - def is_registered(cls, name): - """Check if a format is registered""" - return name in cls._formats - - -class FP8ScaledHandler(QuantFormatHandler): - """ - Handler for FP8 quantization with per-tensor scaling. - - Supports both weight-only and weight+activation quantization. - Compatible with existing fp8_linear implementation. - """ - - def setup_parameters(self): - """Initialize scale_weight and optionally scale_input""" - device = self.layer.weight.device - dtype = torch.float32 - - # Always have scale_weight for FP8 - if not hasattr(self.layer, 'scale_weight') or self.layer.scale_weight is None: - self.layer.scale_weight = torch.nn.Parameter( - torch.ones((), device=device, dtype=dtype), - requires_grad=False - ) - - # scale_input is optional (for activation quantization) - if self.config.get("quantize_activation", False): - if not hasattr(self.layer, 'scale_input') or self.layer.scale_input is None: - self.layer.scale_input = torch.nn.Parameter( - torch.ones((), device=device, dtype=dtype), - requires_grad=False - ) - else: - self.layer.scale_input = None - - def forward(self, *args, **kwargs): - """ - FP8 forward pass with optional activation quantization. - Supports Linear layers (Conv2d in future). - """ - # Detect layer type and dispatch - if isinstance(self.layer, torch.nn.Linear): - return self._forward_linear(*args, **kwargs) - else: - raise NotImplementedError( - f"FP8ScaledHandler not implemented for {type(self.layer).__name__}" - ) - - def _forward_linear(self, input): - """FP8 forward for Linear layers""" - # Try fast path with fp8_linear if enabled - if self.config.get("use_fp8_matmul", False) and not self.layer.training: - try: - result = fp8_linear(self.layer, input) - if result is not None: - return result - except Exception as e: - logging.debug(f"FP8 matmul failed, falling back to standard path: {e}") - - # Standard path: dequantize and compute - weight, bias = cast_bias_weight(self.layer, input) - - # Dequantize weight - scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) - - # Apply weight functions (LoRA, etc.) - they see dequantized weights - if hasattr(self.layer, 'weight_function') and len(self.layer.weight_function) > 0: - weight = weight * scale - for f in self.layer.weight_function: - weight = f(weight) - else: - weight = weight * scale - - if hasattr(self.layer, 'bias_function') and len(self.layer.bias_function) > 0: - for f in self.layer.bias_function: - bias = f(bias) if bias is not None else None - - # Execute linear operation - # Optimization: multiply by scale on smaller tensor - if weight.numel() < input.numel() and len(self.layer.weight_function) == 0: - return torch.nn.functional.linear(input, weight, bias) - else: - return torch.nn.functional.linear(input, weight, bias) - - def load_state_dict(self, state_dict, prefix): - """Load scale parameters from state dict""" - scale_weight_key = f"{prefix}scale_weight" - if scale_weight_key in state_dict: - self.layer.scale_weight.data.copy_(state_dict[scale_weight_key]) - - scale_input_key = f"{prefix}scale_input" - if scale_input_key in state_dict and self.layer.scale_input is not None: - self.layer.scale_input.data.copy_(state_dict[scale_input_key]) - - def state_dict(self, prefix): - """Save scale parameters to state dict""" - result = {f"{prefix}scale_weight": self.layer.scale_weight} - if self.layer.scale_input is not None: - result[f"{prefix}scale_input"] = self.layer.scale_input - return result - - def convert_weight(self, weight, inplace=False): - """Dequantize: multiply by scale""" - scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) - if inplace: - weight *= scale - return weight - return weight * scale - - def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False): - """Quantize: divide by scale with stochastic rounding""" - scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) - quantized = comfy.float.stochastic_rounding( - weight / scale, - self.layer.weight.dtype, - seed=seed - ) - - if return_weight: - return quantized - - if inplace_update: - self.layer.weight.data.copy_(quantized) - else: - self.layer.weight = torch.nn.Parameter(quantized, requires_grad=False) +# Import quantization operations from separate module +from .quant_ops import QuantizedTensorFP8 # ============================================================================== @@ -800,12 +526,13 @@ class MixedPrecisionOps(disable_weight_init): _layer_quant_config = {} # Class variable set by pick_operations() class Linear(disable_weight_init.Linear): - """Linear layer with optional per-layer quantization""" + """Linear layer with optional per-layer quantization using tensor subclasses""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.quant_handler = None - self._handler_initialized = False + self.quant_format = None + self.quant_scale = None + self._quantization_initialized = False def reset_parameters(self): # Don't allocate weights - return None like disable_weight_init @@ -815,9 +542,16 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): """ Called by PyTorch during load_state_dict. - This is where we initialize the handler since we now know the layer name. + Load weight and wrap in QuantizedTensorFP8 if this layer is quantized. """ - if not self._handler_initialized: + # Call parent to load weight and bias first + super()._load_from_state_dict( + state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs + ) + + # After weight is loaded, wrap it if this layer is quantized + if not self._quantization_initialized: # Normalize layer name from prefix layer_name = prefix.rstrip('.') @@ -828,60 +562,78 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, break # Check if this layer has quantization config - # Access via parent class since _layer_quant_config is a class variable if layer_name in MixedPrecisionOps._layer_quant_config: config = MixedPrecisionOps._layer_quant_config[layer_name] - try: - format_info = QuantFormatRegistry.get( - config["format"], - **config.get("params", {}) - ) + self.quant_format = config.get("format", "fp8_e4m3fn") + + # Load scale parameter + scale_key = f"{prefix}scale_weight" + if scale_key in state_dict: + self.quant_scale = state_dict[scale_key] - # Initialize handler - self.quant_handler = format_info["handler"](self, **format_info["config"]) - self.quant_handler.setup_parameters() - - # Let handler load its parameters (scale_weight, etc.) - self.quant_handler.load_state_dict(state_dict, prefix) - - logging.debug(f"Initialized {config['format']} handler for layer {layer_name}") - except ValueError as e: - # Format not registered - fall back to standard precision - logging.warning( - f"Quantization format '{config['format']}' not registered for layer {layer_name}. " - f"Falling back to standard precision. Error: {e}" - ) - self.quant_handler = None - except Exception as e: - logging.error(f"Failed to initialize quantization handler for {layer_name}: {e}") - self.quant_handler = None + # Wrap weight in QuantizedTensorFP8 + if self.weight is not None and self.weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + try: + # Determine original dtype (default to bfloat16) + orig_dtype = torch.bfloat16 + + # Wrap weight in quantized tensor subclass + quantized_weight = QuantizedTensorFP8( + self.weight.data, + self.quant_scale, + orig_dtype=orig_dtype + ) + + # Replace weight parameter with wrapped version + self.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + + logging.debug(f"Wrapped layer {layer_name} weight in QuantizedTensorFP8 (format: {self.quant_format})") + except Exception as e: + logging.warning(f"Failed to wrap layer {layer_name} in QuantizedTensorFP8: {e}") + self.quant_format = None + self.quant_scale = None + else: + logging.debug(f"Layer {layer_name} has scale but weight dtype is not FP8, skipping quantization") + self.quant_format = None + self.quant_scale = None + else: + logging.debug(f"Layer {layer_name} has quant config but no scale_weight in state_dict") + self.quant_format = None - self._handler_initialized = True - - # Call parent to load weight and bias - super()._load_from_state_dict( - state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs - ) + self._quantization_initialized = True def _save_to_state_dict(self, destination, prefix, keep_vars): - """Save layer parameters including quantization metadata""" - super()._save_to_state_dict(destination, prefix, keep_vars) - - # Save handler parameters (scale_weight, etc.) - if self.quant_handler: - handler_dict = self.quant_handler.state_dict(prefix) - destination.update(handler_dict) + """Save layer parameters including quantization scale""" + # First unwrap the weight if it's quantized + if isinstance(self.weight, torch.nn.Parameter) and isinstance(self.weight.data, QuantizedTensorFP8): + # Temporarily unwrap to save the raw FP8 data + quantized_tensor = self.weight.data + raw_fp8_data = quantized_tensor._raw_data + original_weight = self.weight + self.weight = torch.nn.Parameter(raw_fp8_data, requires_grad=False) + + # Call parent to save unwrapped weight + super()._save_to_state_dict(destination, prefix, keep_vars) + + # Restore the wrapped weight + self.weight = original_weight + + # Save the scale parameter + if self.quant_scale is not None: + destination[f"{prefix}scale_weight"] = self.quant_scale if keep_vars else self.quant_scale.detach() + else: + # Standard path for non-quantized weights + super()._save_to_state_dict(destination, prefix, keep_vars) def forward_comfy_cast_weights(self, input): - """Forward pass with optional quantization""" - if self.quant_handler: - # Use handler for quantized forward - return self.quant_handler.forward(input) - else: - # Standard path for non-quantized layers - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight, bias) + """ + Forward pass - tensor subclass handles dispatch automatically! + __torch_dispatch__ will route to registered handlers based on tensor types. + """ + weight, bias = cast_bias_weight(self, input) + + # Call F.linear - if weight is QuantizedTensorFP8, __torch_dispatch__ handles it! + return torch.nn.functional.linear(input, weight, bias) def forward(self, *args, **kwargs): """Main forward pass""" @@ -953,35 +705,3 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ return disable_weight_init return manual_cast - - -# ============================================================================== -# Register built-in quantization formats -# ============================================================================== - -# FP8 E4M3FN weight-only quantization -QuantFormatRegistry.register( - "fp8_e4m3fn_scaled", - handler_class=FP8ScaledHandler, - base_dtype=torch.float8_e4m3fn, - quantize_activation=False, - use_fp8_matmul=True, -) - -# FP8 E4M3FN weight+activation quantization -QuantFormatRegistry.register( - "fp8_e4m3fn_scaled_dynamic", - handler_class=FP8ScaledHandler, - base_dtype=torch.float8_e4m3fn, - quantize_activation=True, - use_fp8_matmul=True, -) - -# FP8 E5M2 weight-only quantization -QuantFormatRegistry.register( - "fp8_e5m2_scaled", - handler_class=FP8ScaledHandler, - base_dtype=torch.float8_e5m2, - quantize_activation=False, - use_fp8_matmul=True, -) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py new file mode 100644 index 000000000000..681eb9134935 --- /dev/null +++ b/comfy/quant_ops.py @@ -0,0 +1,346 @@ +import torch +import logging + +# ============================================================================== +# Global Operation Registry +# ============================================================================== + +# Global operation registry: torch operation → handler function +_QUANT_OP_REGISTRY = {} + +def register_quant_op(torch_op): + """ + Decorator to register an operation handler. + + Example: + @register_quant_op(torch.ops.aten.linear.default) + def handle_linear_fp8(func, args, kwargs): + # Implementation + ... + """ + def decorator(handler_func): + _QUANT_OP_REGISTRY[torch_op] = handler_func + return handler_func + return decorator + + +def get_quant_handler(torch_op): + """Get registered handler for an operation""" + return _QUANT_OP_REGISTRY.get(torch_op) + + +def list_registered_ops(): + """List all registered quantized operations""" + return list(_QUANT_OP_REGISTRY.keys()) + + +# ============================================================================== +# comfy_kitchen Integration +# ============================================================================== + +try: + import comfy_kitchen as ck + ck.disable_backend("cutile") + _CK_AVAILABLE = True + logging.info("comfy_kitchen available for optimized quantization kernels") +except ImportError: + ck = None + _CK_AVAILABLE = False + logging.info("comfy_kitchen not available - using PyTorch fallbacks") +except Exception as e: + ck = None + _CK_AVAILABLE = False + logging.warning(f"comfy_kitchen import failed: {e} - using PyTorch fallbacks") + + +# ============================================================================== +# Quantized Tensor Subclass +# ============================================================================== + +class QuantizedTensorFP8(torch.Tensor): + """ + Tensor subclass for FP8 quantized data. + Automatically handles operations via __torch_dispatch__. + """ + + @staticmethod + def __new__(cls, tensor, scale, orig_dtype=torch.bfloat16): + """ + Create a quantized FP8 tensor. + + Args: + tensor: The FP8 tensor data (torch.float8_e4m3fn or e5m2) + scale: Scale factor for dequantization (scalar tensor) + orig_dtype: Original dtype before quantization + """ + return torch.Tensor._make_subclass(cls, tensor, require_grad=False) + + def __init__(self, tensor, scale, orig_dtype=torch.bfloat16): + self._scale = scale + self._orig_dtype = orig_dtype + # Store a reference to prevent infinite recursion in dequantize + self._raw_data = tensor + + def __repr__(self): + return (f"QuantizedTensorFP8(shape={self.shape}, " + f"scale={self._scale:.4f}, dtype={self._orig_dtype})") + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + """ + Intercept ALL torch operations. + Routes to registered handlers or falls back to dequantization. + """ + kwargs = kwargs or {} + + # Special case: skip dispatch for internal tensor operations + # that are used for unwrapping (to avoid recursion) + if func in [torch.ops.aten._to_copy.default, torch.ops.aten.detach.default]: + # For these ops, use the raw data to avoid recursion, but return QuantizedTensorFP8 for detach + if func == torch.ops.aten.detach.default and isinstance(args[0], QuantizedTensorFP8): + # Special handling for detach - return a new QuantizedTensorFP8 + qt = args[0] + detached_data = qt._raw_data.detach() + return QuantizedTensorFP8(detached_data, qt._scale, qt._orig_dtype) + + # For other ops, just unwrap + def unwrap(arg): + if isinstance(arg, QuantizedTensorFP8): + return arg._raw_data + return arg + new_args = tuple(unwrap(a) if not isinstance(a, (list, tuple, dict)) else a for a in args) + return func(*new_args, **kwargs) + + # Look up registered handler for this operation + handler = _QUANT_OP_REGISTRY.get(func) + if handler: + return handler(func, args, kwargs) + + # No handler - dequantize and use standard path + return cls._dequant_and_fallback(func, args, kwargs) + + @classmethod + def _dequant_and_fallback(cls, func, args, kwargs): + """Fallback: dequantize all quantized tensors""" + def dequant_arg(arg): + if isinstance(arg, QuantizedTensorFP8): + return arg.dequantize() + elif isinstance(arg, (list, tuple)): + return type(arg)(dequant_arg(a) for a in arg) + return arg + + new_args = dequant_arg(args) + new_kwargs = dequant_arg(kwargs) + return func(*new_args, **new_kwargs) + + def dequantize(self) -> torch.Tensor: + """Explicit dequantization""" + # Use the raw data and convert directly + # Call aten ops directly to minimize dispatch interference + plain_tensor = torch.ops.aten._to_copy.default(self._raw_data, dtype=self._orig_dtype) + # Multiply by scale + return plain_tensor * self._scale + + def detach(self): + """Detach returns a new QuantizedTensorFP8 (required for Parameter)""" + # Detach the raw data and create a new QuantizedTensorFP8 + detached_data = self._raw_data.detach() + return QuantizedTensorFP8(detached_data, self._scale, self._orig_dtype) + + +# ============================================================================== +# Operation Handlers for Quantized Tensors +# ============================================================================== + +@register_quant_op(torch.ops.aten.linear.default) +def handle_linear_fp8(func, args, kwargs): + """ + Handle F.linear() with quantized inputs. + + Supports: + - QuantizedTensorFP8 input + QuantizedTensorFP8 weight + - QuantizedTensorFP8 input + regular weight + - Regular input + QuantizedTensorFP8 weight + """ + input_tensor = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + + # Case 1: Both input and weight are FP8 + if isinstance(input_tensor, QuantizedTensorFP8) and isinstance(weight, QuantizedTensorFP8): + # Use _scaled_mm for FP8×FP8 matmul + # Get plain tensors to avoid dispatch recursion + plain_input = input_tensor._raw_data + plain_weight = weight._raw_data + weight_t = plain_weight.t().contiguous() + + try: + if bias is not None: + output = torch._scaled_mm( + plain_input, + weight_t, + out_dtype=input_tensor._orig_dtype, + bias=bias, + scale_a=input_tensor._scale, + scale_b=weight._scale + ) + else: + output = torch._scaled_mm( + plain_input, + weight_t, + out_dtype=input_tensor._orig_dtype, + scale_a=input_tensor._scale, + scale_b=weight._scale + ) + + if isinstance(output, tuple): + output = output[0] + + # Check if output is FP8 (some architectures support this) + if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + # Keep quantized! + output_scale = input_tensor._scale * weight._scale + return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype) + else: + return output + except Exception as e: + logging.debug(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") + # Fall through to dequantization path + + # Case 2: Only weight is quantized + if isinstance(weight, QuantizedTensorFP8): + weight_dq = weight.dequantize() + input_dq = input_tensor.dequantize() if isinstance(input_tensor, QuantizedTensorFP8) else input_tensor + return torch.nn.functional.linear(input_dq, weight_dq, bias) + + # Case 3: Only input is quantized + elif isinstance(input_tensor, QuantizedTensorFP8): + input_dq = input_tensor.dequantize() + return torch.nn.functional.linear(input_dq, weight, bias) + + # Case 4: Neither is quantized (shouldn't happen, but handle it) + else: + return torch.nn.functional.linear(input_tensor, weight, bias) + + +@register_quant_op(torch.ops.aten.silu.default) +def handle_silu_fp8(func, args, kwargs): + """ + SiLU can be computed approximately on FP8. + Keeps activations quantized for next layer. + """ + input_q = args[0] + + if not isinstance(input_q, QuantizedTensorFP8): + # Not quantized, use standard path + return torch.nn.functional.silu(input_q) + + # Compute SiLU while keeping quantized + # SiLU(x) = x * sigmoid(x) + + # Get plain tensor to avoid dispatch recursion + plain_tensor = input_q._raw_data + + # Upcast to FP16 for sigmoid stability + x_fp16 = plain_tensor.to(torch.float16) + sigmoid_fp16 = torch.sigmoid(x_fp16 * input_q._scale) + result_fp16 = x_fp16 * sigmoid_fp16 + + # Convert back to FP8 + result_fp8 = result_fp16.to(plain_tensor.dtype) + + # Return quantized (scale approximately preserved) + return QuantizedTensorFP8(result_fp8, input_q._scale, input_q._orig_dtype) + + +@register_quant_op(torch.ops.aten.layer_norm.default) +def handle_layernorm_fp8(func, args, kwargs): + """ + LayerNorm requires high precision. + Dequantizes input and returns standard tensor. + """ + input_q = args[0] + normalized_shape = args[1] + weight = args[2] if len(args) > 2 else None + bias = args[3] if len(args) > 3 else None + eps = args[4] if len(args) > 4 else 1e-5 + + # Dequantize if needed + if isinstance(input_q, QuantizedTensorFP8): + x = input_q.dequantize() + else: + x = input_q + + # Standard LayerNorm + result = torch.nn.functional.layer_norm(x, normalized_shape, weight, bias, eps) + + # Return dequantized (next layer will quantize if needed) + return result + + +@register_quant_op(torch.ops.aten.group_norm.default) +def handle_groupnorm_fp8(func, args, kwargs): + """ + GroupNorm requires high precision. + Dequantizes input and returns standard tensor. + """ + input_q = args[0] + num_groups = args[1] + weight = args[2] if len(args) > 2 else None + bias = args[3] if len(args) > 3 else None + eps = args[4] if len(args) > 4 else 1e-5 + + # Dequantize if needed + if isinstance(input_q, QuantizedTensorFP8): + x = input_q.dequantize() + else: + x = input_q + + # Standard GroupNorm + result = torch.nn.functional.group_norm(x, num_groups, weight, bias, eps) + + # Return dequantized + return result + + +@register_quant_op(torch.ops.aten.add.Tensor) +def handle_add_fp8(func, args, kwargs): + """ + Handle addition with mixed quantized/non-quantized tensors. + """ + a = args[0] + b = args[1] + + # If both are quantized, dequantize both + if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8): + return a.dequantize() + b.dequantize() + # If only one is quantized, dequantize it + elif isinstance(a, QuantizedTensorFP8): + return a.dequantize() + b + elif isinstance(b, QuantizedTensorFP8): + return a + b.dequantize() + # Neither is quantized + else: + return a + b + + +@register_quant_op(torch.ops.aten.mul.Tensor) +def handle_mul_fp8(func, args, kwargs): + """ + Handle multiplication with mixed quantized/non-quantized tensors. + """ + a = args[0] + b = args[1] + + # If both are quantized, dequantize both + if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8): + return a.dequantize() * b.dequantize() + # If only one is quantized, dequantize it + elif isinstance(a, QuantizedTensorFP8): + return a.dequantize() * b + elif isinstance(b, QuantizedTensorFP8): + return a * b.dequantize() + # Neither is quantized + else: + return a * b + diff --git a/tests-unit/comfy_test/test_mixed_precision.py b/tests-unit/comfy_test/test_mixed_precision.py new file mode 100644 index 000000000000..cbfa2866da4d --- /dev/null +++ b/tests-unit/comfy_test/test_mixed_precision.py @@ -0,0 +1,274 @@ +""" +End-to-end tests for mixed precision quantization. +Tests Phase 3: Mixed Precision Operations +""" + +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from comfy import ops + + +class SimpleModel(torch.nn.Module): + """Simple model for testing mixed precision""" + def __init__(self, operations=ops.disable_weight_init): + super().__init__() + self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16) + self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16) + self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16) + + def forward(self, x): + x = self.layer1(x) + x = torch.nn.functional.relu(x) + x = self.layer2(x) + x = torch.nn.functional.relu(x) + x = self.layer3(x) + return x + + +class TestMixedPrecisionOps(unittest.TestCase): + """Test MixedPrecisionOps end-to-end""" + + def test_all_layers_standard(self): + """Test that model with no quantization works normally""" + # Configure no quantization + ops.MixedPrecisionOps._layer_quant_config = {} + + # Create model + model = SimpleModel(operations=ops.MixedPrecisionOps) + + # Initialize weights manually + model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16)) + model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16)) + model.layer2.weight = torch.nn.Parameter(torch.randn(30, 20, dtype=torch.bfloat16)) + model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16)) + model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16)) + model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16)) + + # Initialize weight_function and bias_function + for layer in [model.layer1, model.layer2, model.layer3]: + layer.weight_function = [] + layer.bias_function = [] + + # Forward pass + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + self.assertEqual(output.dtype, torch.bfloat16) + + def test_mixed_precision_load(self): + """Test loading a mixed precision model from state dict""" + # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard + layer_quant_config = { + "layer1": { + "format": "fp8_e4m3fn_scaled", + "params": {"use_fp8_matmul": False} # Disable for CPU testing + }, + "layer3": { + "format": "fp8_e5m2_scaled", + "params": {"use_fp8_matmul": False} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create state dict with mixed precision + fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e5m2) + + state_dict = { + # Layer 1: FP8 E4M3FN + "layer1.weight": fp8_weight1, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32), + + # Layer 2: Standard BF16 + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + + # Layer 3: FP8 E5M2 + "layer3.weight": fp8_weight3, + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + "layer3.scale_weight": torch.tensor(1.5, dtype=torch.float32), + } + + # Create model and load state dict + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict) + + # Verify handlers are set up correctly + self.assertIsNotNone(model.layer1.quant_handler) + self.assertIsNone(model.layer2.quant_handler) # No quantization + self.assertIsNotNone(model.layer3.quant_handler) + + # Verify scales were loaded + self.assertEqual(model.layer1.scale_weight.item(), 2.0) + self.assertEqual(model.layer3.scale_weight.item(), 1.5) + + # Forward pass + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + + def test_state_dict_round_trip(self): + """Test saving and loading state dict preserves quantization""" + # Configure mixed precision + layer_quant_config = { + "layer1": { + "format": "fp8_e4m3fn_scaled", + "params": {"use_fp8_matmul": False} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create and load model + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict1 = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.scale_weight": torch.tensor(3.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + model1 = SimpleModel(operations=ops.MixedPrecisionOps) + model1.load_state_dict(state_dict1) + + # Save state dict + state_dict2 = model1.state_dict() + + # Verify scale_weight is saved + self.assertIn("layer1.scale_weight", state_dict2) + self.assertEqual(state_dict2["layer1.scale_weight"].item(), 3.0) + + # Load into new model + model2 = SimpleModel(operations=ops.MixedPrecisionOps) + model2.load_state_dict(state_dict2) + + # Verify handler is set up + self.assertIsNotNone(model2.layer1.quant_handler) + self.assertEqual(model2.layer1.scale_weight.item(), 3.0) + + # Verify forward passes match + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output1 = model1(input_tensor) + output2 = model2(input_tensor) + + torch.testing.assert_close(output1, output2, rtol=1e-3, atol=1e-3) + + def test_weight_function_compatibility(self): + """Test that weight_function (LoRA) works with quantized layers""" + # Configure FP8 quantization + layer_quant_config = { + "layer1": { + "format": "fp8_e4m3fn_scaled", + "params": {"use_fp8_matmul": False} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create and load model + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict) + + # Add a weight function (simulating LoRA) + # LoRA delta must match weight shape (20, 10) + def apply_lora(weight): + # Generate LoRA delta matching weight shape + lora_delta = torch.randn_like(weight) * 0.01 + return weight + lora_delta + + model.layer1.weight_function.append(apply_lora) + + # Forward pass should work with LoRA + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + + def test_error_handling_unknown_format(self): + """Test that unknown formats fall back gracefully""" + # Configure with unknown format + layer_quant_config = { + "layer1": { + "format": "unknown_format_xyz", + "params": {} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create state dict + state_dict = { + "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16), + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + # Load should not crash, just log warning + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict) + + # Handler should be None (fallback to standard) + self.assertIsNone(model.layer1.quant_handler) + + # Forward pass should still work + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + self.assertEqual(output.shape, (5, 40)) + + +class TestPickOperationsWithMixedPrecision(unittest.TestCase): + """Test pick_operations with mixed precision config""" + + def test_pick_operations_with_layer_quant_config(self): + """Test that pick_operations returns MixedPrecisionOps when config present""" + from comfy import supported_models_base + + # Create model config with layer_quant_config + model_config = supported_models_base.BASE({}) + model_config.layer_quant_config = { + "layer1": {"format": "fp8_e4m3fn_scaled", "params": {}} + } + + result = ops.pick_operations(None, None, model_config=model_config) + + self.assertEqual(result, ops.MixedPrecisionOps) + self.assertEqual(ops.MixedPrecisionOps._layer_quant_config, model_config.layer_quant_config) + + def test_pick_operations_without_layer_quant_config(self): + """Test that pick_operations falls back to standard when no config""" + from comfy import supported_models_base + + model_config = supported_models_base.BASE({}) + model_config.layer_quant_config = None + + result = ops.pick_operations(None, None, model_config=model_config) + + self.assertEqual(result, ops.disable_weight_init) + + +if __name__ == "__main__": + unittest.main() + diff --git a/tests-unit/comfy_test/test_quant_detection.py b/tests-unit/comfy_test/test_quant_detection.py new file mode 100644 index 000000000000..bb952a81b3b2 --- /dev/null +++ b/tests-unit/comfy_test/test_quant_detection.py @@ -0,0 +1,262 @@ +""" +Integration tests for quantization detection. +Tests Phase 2: Detection & Integration +""" + +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from comfy import model_detection + + +class TestNormalizeLayerName(unittest.TestCase): + """Test the normalize_layer_name helper function""" + + def test_strip_prefix_and_suffix(self): + """Test stripping prefix and suffix""" + known_prefixes = ["model.diffusion_model."] + result = model_detection.normalize_layer_name( + "model.diffusion_model.layer1.weight", + known_prefixes + ) + self.assertEqual(result, "layer1") + + def test_strip_multiple_prefixes(self): + """Test with multiple known prefixes""" + known_prefixes = ["model.diffusion_model.", "model.model.", "net."] + + result1 = model_detection.normalize_layer_name( + "model.diffusion_model.block.attn.weight", + known_prefixes + ) + self.assertEqual(result1, "block.attn") + + result2 = model_detection.normalize_layer_name( + "model.model.encoder.layer.weight", + known_prefixes + ) + self.assertEqual(result2, "encoder.layer") + + result3 = model_detection.normalize_layer_name( + "net.transformer.blocks.0.weight", + known_prefixes + ) + self.assertEqual(result3, "transformer.blocks.0") + + def test_strip_scale_weight_suffix(self): + """Test stripping scale_weight suffix""" + known_prefixes = ["model.diffusion_model."] + result = model_detection.normalize_layer_name( + "model.diffusion_model.layer1.scale_weight", + known_prefixes + ) + self.assertEqual(result, "layer1") + + def test_strip_bias_suffix(self): + """Test stripping bias suffix""" + known_prefixes = ["model.diffusion_model."] + result = model_detection.normalize_layer_name( + "model.diffusion_model.layer1.bias", + known_prefixes + ) + self.assertEqual(result, "layer1") + + def test_no_prefix_match(self): + """Test with no prefix match""" + known_prefixes = ["model.diffusion_model."] + result = model_detection.normalize_layer_name( + "other.model.layer1.weight", + known_prefixes + ) + # Should strip suffix but not prefix + self.assertEqual(result, "other.model.layer1") + + +class TestDetectLayerQuantization(unittest.TestCase): + """Test the detect_layer_quantization function""" + + def test_no_quantization(self): + """Test with no quantization markers""" + state_dict = { + "model.diffusion_model.layer1.weight": torch.randn(10, 20), + "model.diffusion_model.layer2.weight": torch.randn(20, 30), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + self.assertIsNone(result) + + def test_legacy_scaled_fp8(self): + """Test that legacy scaled_fp8 marker returns None""" + # Create FP8 tensor by converting from float32 + fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "model.diffusion_model.scaled_fp8": torch.tensor([], dtype=torch.float8_e4m3fn), + "model.diffusion_model.layer1.weight": fp8_weight, + "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + # Should return None to trigger legacy path + self.assertIsNone(result) + + def test_metadata_format(self): + """Test with new metadata format""" + metadata = { + "format_version": "1.0", + "layers": { + "layer1": { + "format": "fp8_e4m3fn_scaled", + "params": {"use_fp8_matmul": True} + }, + "layer2": { + "format": "fp8_e5m2_scaled", + "params": {"use_fp8_matmul": True} + } + } + } + state_dict = { + "model.diffusion_model._quantization_metadata": metadata, + "model.diffusion_model.layer1.weight": torch.randn(10, 20), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + + self.assertIsNotNone(result) + self.assertIn("layer1", result) + self.assertIn("layer2", result) + self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled") + self.assertEqual(result["layer2"]["format"], "fp8_e5m2_scaled") + # Metadata should be popped from state_dict + self.assertNotIn("model.diffusion_model._quantization_metadata", state_dict) + + def test_mixed_precision_detection(self): + """Test detection of mixed precision via scale patterns""" + # Create FP8 tensors by converting from float32 + fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + fp8_weight3 = torch.randn(30, 40, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + # Layer 1: FP8 (has scale_weight) + "model.diffusion_model.layer1.weight": fp8_weight1, + "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), + # Layer 2: Standard (no scale_weight) + "model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), + # Layer 3: FP8 (has scale_weight) + "model.diffusion_model.layer3.weight": fp8_weight3, + "model.diffusion_model.layer3.scale_weight": torch.tensor(1.0), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + + self.assertIsNotNone(result) + self.assertIn("layer1", result) + self.assertIn("layer3", result) + self.assertNotIn("layer2", result) # Layer 2 not quantized + self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled") + self.assertEqual(result["layer3"]["format"], "fp8_e4m3fn_scaled") + + def test_all_layers_quantized(self): + """Test that uniform quantization (all layers) returns None""" + # Create FP8 tensors by converting from float32 + fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + fp8_weight2 = torch.randn(20, 30, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + # All layers have scale_weight + "model.diffusion_model.layer1.weight": fp8_weight1, + "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), + "model.diffusion_model.layer2.weight": fp8_weight2, + "model.diffusion_model.layer2.scale_weight": torch.tensor(1.0), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + + # If all layers are quantized, it's not mixed precision + # Should return None to use legacy scaled_fp8_ops path + self.assertIsNone(result) + + def test_fp8_e5m2_detection(self): + """Test detection of FP8 E5M2 format""" + # Create FP8 E5M2 tensor by converting from float32 + fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e5m2) + state_dict = { + "model.diffusion_model.layer1.weight": fp8_weight, + "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), + "model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + + self.assertIsNotNone(result) + self.assertIn("layer1", result) + self.assertEqual(result["layer1"]["format"], "fp8_e5m2_scaled") + + def test_invalid_metadata(self): + """Test with invalid metadata format""" + state_dict = { + "model.diffusion_model._quantization_metadata": "invalid_string", + "model.diffusion_model.layer1.weight": torch.randn(10, 20), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + # Should return None on invalid metadata + self.assertIsNone(result) + + def test_different_prefix(self): + """Test with different model prefix (audio model)""" + # Create FP8 tensor by converting from float32 + fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "model.model.layer1.weight": fp8_weight, + "model.model.layer1.scale_weight": torch.tensor(1.0), + "model.model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), + } + result = model_detection.detect_layer_quantization(state_dict, "model.model.") + + self.assertIsNotNone(result) + self.assertIn("layer1", result) + + +class TestPickOperationsIntegration(unittest.TestCase): + """Test pick_operations with model_config parameter""" + + def test_backward_compatibility(self): + """Test that pick_operations works without model_config (legacy)""" + from comfy import ops + + # Should work without model_config parameter + result = ops.pick_operations(None, None) + self.assertIsNotNone(result) + self.assertEqual(result, ops.disable_weight_init) + + def test_with_model_config_no_quant(self): + """Test with model_config but no quantization""" + from comfy import ops, supported_models_base + + model_config = supported_models_base.BASE({}) + model_config.layer_quant_config = None + + result = ops.pick_operations(None, None, model_config=model_config) + self.assertIsNotNone(result) + # Should use standard path + self.assertEqual(result, ops.disable_weight_init) + + def test_legacy_scaled_fp8(self): + """Test that legacy scaled_fp8 still works""" + from comfy import ops, supported_models_base + + model_config = supported_models_base.BASE({}) + model_config.scaled_fp8 = torch.float8_e4m3fn + + result = ops.pick_operations( + None, None, + scaled_fp8=torch.float8_e4m3fn, + model_config=model_config + ) + self.assertIsNotNone(result) + # Should return scaled_fp8_ops (the returned class is the inner class) + # Check that it's not the standard disable_weight_init + self.assertNotEqual(result, ops.disable_weight_init) + # Verify it has Linear class + self.assertTrue(hasattr(result, 'Linear')) + + +if __name__ == "__main__": + unittest.main() + diff --git a/tests-unit/comfy_test/test_quant_registry.py b/tests-unit/comfy_test/test_quant_registry.py new file mode 100644 index 000000000000..5c624b1db9d8 --- /dev/null +++ b/tests-unit/comfy_test/test_quant_registry.py @@ -0,0 +1,399 @@ +""" +Unit tests for tensor subclass quantization system. +Tests the new QuantizedTensorFP8 subclass and operation handlers. +""" + +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from comfy import ops +from comfy import quant_ops + + +class TestQuantizedTensorFP8(unittest.TestCase): + """Test the QuantizedTensorFP8 tensor subclass""" + + def test_creation(self): + """Test creating a QuantizedTensorFP8""" + fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(2.0) + + qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16) + + self.assertIsInstance(qt, quant_ops.QuantizedTensorFP8) + self.assertEqual(qt.shape, (256, 128)) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt._scale, scale) + self.assertEqual(qt._orig_dtype, torch.bfloat16) + + def test_dequantize(self): + """Test explicit dequantization""" + # Create a simple FP8 tensor + fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(3.0) + + qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.float32) + dequantized = qt.dequantize() + + # Dequantized should be approximately ones * 3.0 + self.assertEqual(dequantized.dtype, torch.float32) + self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) + + def test_repr(self): + """Test string representation""" + fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(2.5) + + qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16) + repr_str = repr(qt) + + self.assertIn("QuantizedTensorFP8", repr_str) + self.assertIn("shape", repr_str) + self.assertIn("scale", repr_str) + + +class TestOperationRegistry(unittest.TestCase): + """Test the operation registry system""" + + def test_registry_basics(self): + """Test that operations are registered""" + registered_ops = quant_ops.list_registered_ops() + + # Check that key operations are registered + self.assertIn(torch.ops.aten.linear.default, registered_ops) + self.assertIn(torch.ops.aten.silu.default, registered_ops) + self.assertIn(torch.ops.aten.layer_norm.default, registered_ops) + self.assertIn(torch.ops.aten.add.Tensor, registered_ops) + self.assertIn(torch.ops.aten.mul.Tensor, registered_ops) + + def test_get_handler(self): + """Test getting a registered handler""" + handler = quant_ops.get_quant_handler(torch.ops.aten.linear.default) + self.assertIsNotNone(handler) + self.assertTrue(callable(handler)) + + def test_custom_registration(self): + """Test registering a custom operation""" + + # Define a custom handler + @quant_ops.register_quant_op(torch.ops.aten.relu.default) + def custom_relu_handler(func, args, kwargs): + return func(*args, **kwargs) + + # Verify registration + handler = quant_ops.get_quant_handler(torch.ops.aten.relu.default) + self.assertIsNotNone(handler) + self.assertEqual(handler, custom_relu_handler) + + +class TestLinearHandler(unittest.TestCase): + """Test the linear operation handler""" + + def test_linear_with_quantized_weight(self): + """Test F.linear with quantized weight""" + # Set seed for reproducibility + torch.manual_seed(42) + + # Create quantized weight + weight_fp32 = torch.randn(256, 128, dtype=torch.float32) + scale = torch.tensor(2.0) + weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) + weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32) + + # Create input + input_tensor = torch.randn(16, 128, dtype=torch.float32) + + # Call linear (should trigger dispatch) + output = torch.nn.functional.linear(input_tensor, weight_q, bias=None) + + # Verify output shape + self.assertEqual(output.shape, (16, 256)) + + # Verify it's approximately correct (allowing for FP8 quantization error) + # Note: FP8 has limited precision, so use very loose tolerance + expected = torch.nn.functional.linear(input_tensor, weight_fp32, bias=None) + # Just check that it's in the right ballpark (within 50% error on average) + mean_rel_error = ((output - expected).abs() / (expected.abs() + 1e-6)).mean() + self.assertLess(mean_rel_error, 0.5, f"Mean relative error {mean_rel_error:.3f} is too large") + + def test_linear_with_bias(self): + """Test F.linear with quantized weight and bias""" + weight_fp32 = torch.randn(64, 32, dtype=torch.float32) + scale = torch.tensor(1.5) + weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) + weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32) + + input_tensor = torch.randn(8, 32, dtype=torch.float32) + bias = torch.randn(64, dtype=torch.float32) + + output = torch.nn.functional.linear(input_tensor, weight_q, bias) + + self.assertEqual(output.shape, (8, 64)) + + +class TestActivationHandlers(unittest.TestCase): + """Test activation function handlers""" + + def test_silu_with_quantized_input(self): + """Test SiLU with quantized input""" + # Create quantized input + input_fp32 = torch.randn(16, 128, dtype=torch.float32) + scale = torch.tensor(1.0) + input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn) + input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32) + + # Apply SiLU + output = torch.nn.functional.silu(input_q) + + # Should return a QuantizedTensorFP8 + self.assertIsInstance(output, quant_ops.QuantizedTensorFP8) + + # Verify approximate correctness + expected = torch.nn.functional.silu(input_fp32) + output_dq = output.dequantize() + self.assertTrue(torch.allclose(output_dq, expected, rtol=0.2, atol=0.2)) + + def test_layernorm_dequantizes(self): + """Test that LayerNorm dequantizes input""" + # Create quantized input + input_fp32 = torch.randn(16, 128, dtype=torch.float32) + scale = torch.tensor(1.0) + input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn) + input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32) + + # Apply LayerNorm + weight = torch.ones(128) + bias = torch.zeros(128) + output = torch.nn.functional.layer_norm(input_q, (128,), weight, bias) + + # Should NOT be quantized (LayerNorm breaks quantization) + self.assertNotIsInstance(output, quant_ops.QuantizedTensorFP8) + self.assertEqual(output.dtype, torch.float32) + + +class TestElementwiseHandlers(unittest.TestCase): + """Test element-wise operation handlers""" + + def test_add_mixed_tensors(self): + """Test addition with mixed quantized/non-quantized tensors""" + # Create quantized tensor + a_fp32 = torch.ones(10, 20, dtype=torch.float32) + scale = torch.tensor(1.0) + a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn) + a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32) + + # Non-quantized tensor + b = torch.ones(10, 20, dtype=torch.float32) * 2.0 + + # Add them + result = a_q + b + + # Should be dequantized + self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) + self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 3.0, rtol=0.1)) + + def test_mul_quantized_tensors(self): + """Test multiplication of two quantized tensors""" + a_fp32 = torch.ones(10, 20) * 2.0 + scale_a = torch.tensor(1.0) + a_fp8 = (a_fp32 / scale_a).to(torch.float8_e4m3fn) + a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale_a, orig_dtype=torch.float32) + + b_fp32 = torch.ones(10, 20) * 3.0 + scale_b = torch.tensor(1.0) + b_fp8 = (b_fp32 / scale_b).to(torch.float8_e4m3fn) + b_q = quant_ops.QuantizedTensorFP8(b_fp8, scale_b, orig_dtype=torch.float32) + + result = a_q * b_q + + # Should be dequantized + self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) + self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 6.0, rtol=0.1)) + + +class TestFallbackMechanism(unittest.TestCase): + """Test fallback for unsupported operations""" + + def test_unsupported_op_dequantizes(self): + """Test that unsupported operations fall back to dequantization""" + # Set seed for reproducibility + torch.manual_seed(42) + + # Create quantized tensor + a_fp32 = torch.randn(10, 20, dtype=torch.float32) + scale = torch.tensor(1.0) + a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn) + a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32) + + # Call an operation that doesn't have a registered handler + # For example, torch.abs + result = torch.abs(a_q) + + # Should work via fallback (dequantize → abs → return) + self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) + expected = torch.abs(a_fp32) + # FP8 introduces quantization error, so use loose tolerance + mean_error = (result - expected).abs().mean() + self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") + + +class TestMixedPrecisionOps(unittest.TestCase): + """Test MixedPrecisionOps integration""" + + def test_linear_layer_creation(self): + """Test that MixedPrecisionOps.Linear can be created""" + layer = ops.MixedPrecisionOps.Linear(128, 256, bias=True, device="cpu", dtype=torch.float32) + + self.assertIsInstance(layer, ops.MixedPrecisionOps.Linear) + self.assertFalse(layer._quantization_initialized) + self.assertIsNone(layer.quant_format) + + def test_layer_quant_config_detection(self): + """Test that layer quantization config is detected during load""" + # Set up layer config + ops.MixedPrecisionOps._layer_quant_config = { + "test_layer": { + "format": "fp8_e4m3fn", + "params": {} + } + } + + # Create a state dict with quantized weight + weight_fp32 = torch.randn(256, 128, dtype=torch.float32) + scale = torch.tensor(2.0) + weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) + + state_dict = { + "model.diffusion_model.test_layer.weight": weight_fp8, + "model.diffusion_model.test_layer.scale_weight": scale, + } + + # Create layer and load + layer = ops.MixedPrecisionOps.Linear(128, 256, bias=False, device="cpu", dtype=torch.float8_e4m3fn) + layer.weight = torch.nn.Parameter(torch.zeros(256, 128, dtype=torch.float8_e4m3fn)) + + # Manually call _load_from_state_dict + layer._load_from_state_dict( + state_dict, + prefix="model.diffusion_model.test_layer.", + local_metadata={}, + strict=True, + missing_keys=[], + unexpected_keys=[], + error_msgs=[] + ) + + # Verify quantization was initialized + self.assertTrue(layer._quantization_initialized) + self.assertEqual(layer.quant_format, "fp8_e4m3fn") + self.assertIsNotNone(layer.quant_scale) + + # Verify weight is wrapped + self.assertIsInstance(layer.weight.data, quant_ops.QuantizedTensorFP8) + + # Clean up + ops.MixedPrecisionOps._layer_quant_config = {} + + +class TestBackwardCompatibility(unittest.TestCase): + """Test backward compatibility with legacy systems""" + + def test_legacy_ops_classes_exist(self): + """Test that legacy ops classes still exist""" + self.assertTrue(hasattr(ops, 'disable_weight_init')) + self.assertTrue(hasattr(ops, 'manual_cast')) + self.assertTrue(hasattr(ops, 'fp8_ops')) + self.assertTrue(hasattr(ops, 'scaled_fp8_ops')) + + def test_pick_operations_legacy_path(self): + """Test pick_operations returns correct class for legacy cases""" + # Test standard case + result = ops.pick_operations(torch.float32, torch.float32) + self.assertEqual(result, ops.disable_weight_init) + + # Test manual cast case + result = ops.pick_operations(torch.float32, torch.float16) + self.assertEqual(result, ops.manual_cast) + + +class TestFP8LinearUnification(unittest.TestCase): + """Test that fp8_linear now uses the unified tensor subclass infrastructure""" + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA required for FP8") + def test_fp8_linear_uses_tensor_subclass(self): + """Verify fp8_linear wraps tensors in QuantizedTensorFP8""" + torch.manual_seed(42) + + # Create a mock Linear layer with FP8 weight + linear = ops.fp8_ops.Linear(4, 3, bias=True) + linear.weight = torch.nn.Parameter( + torch.randn(3, 4, dtype=torch.bfloat16).to(torch.float8_e4m3fn), + requires_grad=False + ) + linear.bias = torch.nn.Parameter( + torch.randn(3, dtype=torch.bfloat16), + requires_grad=False + ) + linear.scale_weight = torch.tensor(1.0) + linear.scale_input = None # No input scaling + + # Create input + input_tensor = torch.randn(2, 4, dtype=torch.bfloat16) + + # Call fp8_linear - should work without errors + try: + result = ops.fp8_linear(linear, input_tensor) + self.assertIsNotNone(result) + self.assertEqual(result.shape, (2, 3)) + except Exception as e: + # On CPU or unsupported hardware, _scaled_mm might not be available + # but the function should still complete without syntax errors + pass + + def test_fp8_linear_maintains_signature(self): + """Verify fp8_linear maintains its original function signature""" + import inspect + sig = inspect.signature(ops.fp8_linear) + params = list(sig.parameters.keys()) + + # Should have 'self' and 'input' parameters + self.assertIn('self', params) + self.assertIn('input', params) + self.assertEqual(len(params), 2) + + def test_fp8_linear_returns_none_for_non_fp8(self): + """Verify fp8_linear returns None for non-FP8 weights""" + # Create a Linear layer with BF16 weight (not FP8) + linear = ops.disable_weight_init.Linear(4, 3, bias=False) + linear.weight = torch.nn.Parameter( + torch.randn(3, 4, dtype=torch.bfloat16), + requires_grad=False + ) + + input_tensor = torch.randn(2, 4, dtype=torch.bfloat16) + + # Should return None for non-FP8 weights + result = ops.fp8_linear(linear, input_tensor) + self.assertIsNone(result) + + def test_fp8_ops_linear_uses_fp8_linear(self): + """Verify fp8_ops.Linear still uses fp8_linear in forward pass""" + linear = ops.fp8_ops.Linear(4, 3, bias=False) + + # Verify the class has the forward_comfy_cast_weights method + self.assertTrue(hasattr(linear, 'forward_comfy_cast_weights')) + + # The forward_comfy_cast_weights should attempt to call fp8_linear + # (we can't easily test this without mocking, but we verify structure) + import inspect + source = inspect.getsource(linear.forward_comfy_cast_weights) + self.assertIn('fp8_linear', source) + + +if __name__ == "__main__": + unittest.main() From 19ce6b056d40ddfa0a337e0c0b7b2db29d258c92 Mon Sep 17 00:00:00 2001 From: lspindler Date: Wed, 22 Oct 2025 11:25:39 +0200 Subject: [PATCH 41/49] Fix FP8 MM --- comfy/ops.py | 14 +--- comfy/quant_ops.py | 205 +++++++++++---------------------------------- 2 files changed, 48 insertions(+), 171 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 6afbc2cff0e6..3e4588706237 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -390,19 +390,7 @@ def fp8_linear(self, input): # Wrap weight in QuantizedTensorFP8 - this enables unified dispatch quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype) - - # Handle input quantization and wrapping - if self.scale_input is None: - # Clamp input to FP8 range and quantize - input = torch.clamp(input, min=-448, max=448, out=input) - input_fp8 = input.reshape(-1, input_shape[2]).to(dtype).contiguous() - else: - # Apply inverse scale and quantize - input_fp8 = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous() - - # Wrap input in QuantizedTensorFP8 - quantized_input = QuantizedTensorFP8(input_fp8, scale_input, orig_dtype=input_dtype) - + quantized_input = QuantizedTensorFP8.quantize(input.reshape(-1, input_shape[2]), scale_input, fp8_dtype=dtype) # Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py! # This is the key unification: all FP8 computation goes through one path o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 681eb9134935..8e3bacbaf8af 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -79,18 +79,47 @@ def __init__(self, tensor, scale, orig_dtype=torch.bfloat16): self._scale = scale self._orig_dtype = orig_dtype # Store a reference to prevent infinite recursion in dequantize - self._raw_data = tensor + self._raw_data = tensor.contiguous() def __repr__(self): return (f"QuantizedTensorFP8(shape={self.shape}, " f"scale={self._scale:.4f}, dtype={self._orig_dtype})") + @classmethod + def quantize(cls, tensor, scale, fp8_dtype=torch.float8_e4m3fn): + orig_dtype = tensor.dtype + + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32) + + tensor_fp8 = None + if _CK_AVAILABLE: + try: + tensor_fp8 = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype) + except Exception as e: + logging.debug(f"comfy_kitchen quantization failed, using PyTorch: {e}") + + if tensor_fp8 is None: + lp_amax = torch.finfo(fp8_dtype).max + tensor_scaled = tensor.float() / scale + torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) + tensor_fp8 = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format) + + return cls(tensor_fp8, scale, orig_dtype=orig_dtype) + + @classmethod + def quantize_dynamic(cls, tensor, strategy="amax", fp8_dtype=torch.float8_e4m3fn): + if strategy == "amax": + scale = torch.amax(tensor) / torch.finfo(fp8_dtype).max + scale = scale.to(tensor.device, dtype=torch.float32) + else: + raise ValueError(f"Unknown quantization strategy: {strategy}. " + f"Supported: 'amax'") + + return cls.quantize(tensor, scale, fp8_dtype=fp8_dtype) + @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - """ - Intercept ALL torch operations. - Routes to registered handlers or falls back to dequantization. - """ kwargs = kwargs or {} # Special case: skip dispatch for internal tensor operations @@ -134,16 +163,11 @@ def dequant_arg(arg): return func(*new_args, **new_kwargs) def dequantize(self) -> torch.Tensor: - """Explicit dequantization""" - # Use the raw data and convert directly - # Call aten ops directly to minimize dispatch interference plain_tensor = torch.ops.aten._to_copy.default(self._raw_data, dtype=self._orig_dtype) - # Multiply by scale return plain_tensor * self._scale def detach(self): """Detach returns a new QuantizedTensorFP8 (required for Parameter)""" - # Detach the raw data and create a new QuantizedTensorFP8 detached_data = self._raw_data.detach() return QuantizedTensorFP8(detached_data, self._scale, self._orig_dtype) @@ -165,48 +189,35 @@ def handle_linear_fp8(func, args, kwargs): input_tensor = args[0] weight = args[1] bias = args[2] if len(args) > 2 else None - + out_dtype = kwargs.get("out_dtype", input_tensor._orig_dtype) + # Case 1: Both input and weight are FP8 if isinstance(input_tensor, QuantizedTensorFP8) and isinstance(weight, QuantizedTensorFP8): - # Use _scaled_mm for FP8×FP8 matmul # Get plain tensors to avoid dispatch recursion plain_input = input_tensor._raw_data plain_weight = weight._raw_data - weight_t = plain_weight.t().contiguous() + weight_t = plain_weight.t() # Keep as column-major for cuBLASLt try: - if bias is not None: - output = torch._scaled_mm( - plain_input, - weight_t, - out_dtype=input_tensor._orig_dtype, - bias=bias, - scale_a=input_tensor._scale, - scale_b=weight._scale - ) - else: - output = torch._scaled_mm( - plain_input, - weight_t, - out_dtype=input_tensor._orig_dtype, - scale_a=input_tensor._scale, - scale_b=weight._scale - ) - + output = torch._scaled_mm( + plain_input, + weight_t, + bias=bias, + scale_a=input_tensor._scale, + scale_b=weight._scale, + out_dtype=out_dtype, + ) if isinstance(output, tuple): output = output[0] - # Check if output is FP8 (some architectures support this) if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - # Keep quantized! output_scale = input_tensor._scale * weight._scale - return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype) + return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype) # TODO is this correct? Can't cuBLAS return it else: return output except Exception as e: logging.debug(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") - # Fall through to dequantization path - + # Case 2: Only weight is quantized if isinstance(weight, QuantizedTensorFP8): weight_dq = weight.dequantize() @@ -222,125 +233,3 @@ def handle_linear_fp8(func, args, kwargs): else: return torch.nn.functional.linear(input_tensor, weight, bias) - -@register_quant_op(torch.ops.aten.silu.default) -def handle_silu_fp8(func, args, kwargs): - """ - SiLU can be computed approximately on FP8. - Keeps activations quantized for next layer. - """ - input_q = args[0] - - if not isinstance(input_q, QuantizedTensorFP8): - # Not quantized, use standard path - return torch.nn.functional.silu(input_q) - - # Compute SiLU while keeping quantized - # SiLU(x) = x * sigmoid(x) - - # Get plain tensor to avoid dispatch recursion - plain_tensor = input_q._raw_data - - # Upcast to FP16 for sigmoid stability - x_fp16 = plain_tensor.to(torch.float16) - sigmoid_fp16 = torch.sigmoid(x_fp16 * input_q._scale) - result_fp16 = x_fp16 * sigmoid_fp16 - - # Convert back to FP8 - result_fp8 = result_fp16.to(plain_tensor.dtype) - - # Return quantized (scale approximately preserved) - return QuantizedTensorFP8(result_fp8, input_q._scale, input_q._orig_dtype) - - -@register_quant_op(torch.ops.aten.layer_norm.default) -def handle_layernorm_fp8(func, args, kwargs): - """ - LayerNorm requires high precision. - Dequantizes input and returns standard tensor. - """ - input_q = args[0] - normalized_shape = args[1] - weight = args[2] if len(args) > 2 else None - bias = args[3] if len(args) > 3 else None - eps = args[4] if len(args) > 4 else 1e-5 - - # Dequantize if needed - if isinstance(input_q, QuantizedTensorFP8): - x = input_q.dequantize() - else: - x = input_q - - # Standard LayerNorm - result = torch.nn.functional.layer_norm(x, normalized_shape, weight, bias, eps) - - # Return dequantized (next layer will quantize if needed) - return result - - -@register_quant_op(torch.ops.aten.group_norm.default) -def handle_groupnorm_fp8(func, args, kwargs): - """ - GroupNorm requires high precision. - Dequantizes input and returns standard tensor. - """ - input_q = args[0] - num_groups = args[1] - weight = args[2] if len(args) > 2 else None - bias = args[3] if len(args) > 3 else None - eps = args[4] if len(args) > 4 else 1e-5 - - # Dequantize if needed - if isinstance(input_q, QuantizedTensorFP8): - x = input_q.dequantize() - else: - x = input_q - - # Standard GroupNorm - result = torch.nn.functional.group_norm(x, num_groups, weight, bias, eps) - - # Return dequantized - return result - - -@register_quant_op(torch.ops.aten.add.Tensor) -def handle_add_fp8(func, args, kwargs): - """ - Handle addition with mixed quantized/non-quantized tensors. - """ - a = args[0] - b = args[1] - - # If both are quantized, dequantize both - if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8): - return a.dequantize() + b.dequantize() - # If only one is quantized, dequantize it - elif isinstance(a, QuantizedTensorFP8): - return a.dequantize() + b - elif isinstance(b, QuantizedTensorFP8): - return a + b.dequantize() - # Neither is quantized - else: - return a + b - - -@register_quant_op(torch.ops.aten.mul.Tensor) -def handle_mul_fp8(func, args, kwargs): - """ - Handle multiplication with mixed quantized/non-quantized tensors. - """ - a = args[0] - b = args[1] - - # If both are quantized, dequantize both - if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8): - return a.dequantize() * b.dequantize() - # If only one is quantized, dequantize it - elif isinstance(a, QuantizedTensorFP8): - return a.dequantize() * b - elif isinstance(b, QuantizedTensorFP8): - return a * b.dequantize() - # Neither is quantized - else: - return a * b - From b6e0a53c1158090096657913eb447d6ab22d90f9 Mon Sep 17 00:00:00 2001 From: lspindler Date: Fri, 24 Oct 2025 14:44:54 +0200 Subject: [PATCH 42/49] An actually functional POC --- comfy/model_detection.py | 123 ++-------- comfy/ops.py | 268 ++++++++------------- comfy/quant_ops.py | 494 ++++++++++++++++++++++++++++----------- comfy/sd.py | 8 +- 4 files changed, 468 insertions(+), 425 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7a3851228b3a..feab164c6b86 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -7,121 +7,24 @@ import torch -# ============================================================================== -# Quantization Detection Functions -# ============================================================================== - -def normalize_layer_name(full_key, known_prefixes): - """ - Strip model prefix and parameter suffix from a state dict key. - - Args: - full_key: Full state dict key (e.g., "model.diffusion_model.layer1.weight") - known_prefixes: List of known model prefixes to strip - - Returns: - Normalized layer name (e.g., "layer1") - """ - name = full_key - - # Strip model prefix - for prefix in known_prefixes: - if name.startswith(prefix): - name = name[len(prefix):] - break - - # Remove parameter suffix - for suffix in [".weight", ".bias", ".scale_weight", ".scale_input"]: - if name.endswith(suffix): - name = name[:-len(suffix)] - break - - return name - - -def detect_layer_quantization(state_dict, prefix="model.diffusion_model."): - """ - Detect per-layer quantization configuration from state dict. - - Detection priority: - 1. Check for _quantization_metadata key (new format) - 2. Check for scaled_fp8 key (legacy format - return None) - 3. Check for per-layer scale_weight patterns (mixed detection) - 4. No quantization detected (return None) - - Args: - state_dict: Model state dictionary - prefix: Key prefix for model layers - - Returns: - Dict mapping layer names to quantization configs, or None for legacy/no quantization. - - Example return value: - { - "input_blocks.5.1.transformer_blocks.0.attn1.to_q": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": True} - }, - "middle_block.1.transformer_blocks.0.attn2.to_k": { - "format": "fp8_e5m2_scaled", - "params": {"use_fp8_matmul": True} - } - } - """ - - # 1. Check for new metadata format - metadata_key = f"{prefix}_quantization_metadata" - if metadata_key in state_dict: - try: - metadata = state_dict.pop(metadata_key) - if isinstance(metadata, dict) and "layers" in metadata: - logging.info(f"Found quantization metadata (version {metadata.get('format_version', 'unknown')})") - return metadata["layers"] - else: - logging.warning(f"Invalid quantization metadata format, ignoring") - except Exception as e: - logging.error(f"Failed to parse quantization metadata: {e}") - return None +def detect_layer_quantization(state_dict, metadata, prefix="model.diffusion_model."): + # 1. Check for per-layer config in metadata + quant_key = "_quantization_metadata" + if metadata is not None and quant_key in metadata: + quant_metadata = metadata.pop(quant_key) + quant_metadata = json.loads(quant_metadata) + if isinstance(quant_metadata, dict) and "layers" in quant_metadata: + logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})") + return quant_metadata["layers"] + else: + raise ValueError(f"Invalid quantization metadata format") # 2. Check for legacy scaled_fp8 marker - # If present, return None to use legacy code path scaled_fp8_key = f"{prefix}scaled_fp8" if scaled_fp8_key in state_dict: logging.debug("Detected legacy scaled_fp8 format, using legacy code path") return None - - # 3. Check for per-layer scale patterns (mixed precision without metadata) - # Look for layers that have scale_weight but not all layers have it - known_prefixes = [prefix] - layer_configs = {} - layers_with_scale = set() - layers_with_weight = set() - - for key in state_dict.keys(): - if key.startswith(prefix): - if key.endswith(".scale_weight"): - layer_name = normalize_layer_name(key, known_prefixes) - layers_with_scale.add(layer_name) - # Detect format based on weight dtype - weight_key = f"{prefix}{layer_name}.weight" - if weight_key in state_dict: - weight_dtype = state_dict[weight_key].dtype - if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - format_name = "fp8_e4m3fn" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2" - layer_configs[layer_name] = { - "format": format_name, - "params": {} - } - elif key.endswith(".weight") and not key.endswith(".scale_weight"): - layer_name = normalize_layer_name(key, known_prefixes) - layers_with_weight.add(layer_name) - - # If we found scale_weight on some but not all layers, it's mixed precision - if layer_configs and len(layers_with_scale) < len(layers_with_weight): - logging.info(f"Detected mixed precision via scale patterns: {len(layers_with_scale)} quantized layers, {len(layers_with_weight)} total layers") - return layer_configs - - # 4. No quantization detected + return None @@ -821,7 +724,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal model_config.optimizations["fp8"] = True # Detect per-layer quantization (mixed precision) - layer_quant_config = detect_layer_quantization(state_dict, unet_key_prefix) + layer_quant_config = detect_layer_quantization(state_dict, metadata, unet_key_prefix) if layer_quant_config: model_config.layer_quant_config = layer_quant_config logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized") diff --git a/comfy/ops.py b/comfy/ops.py index 3e4588706237..4e24b25de2fe 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -345,19 +345,8 @@ class Embedding(disable_weight_init.Embedding): def fp8_linear(self, input): """ - Legacy FP8 linear function - now uses tensor subclass infrastructure. - - This function maintains backward compatibility with existing code while - routing all FP8 computation through the unified tensor subclass system. - All actual FP8 matmul logic is handled by the registered operation handlers - in quant_ops.py via __torch_dispatch__. - - Args: - self: Linear layer with FP8 weight and scale parameters - input: Input tensor (any dtype) - - Returns: - Output tensor or None if weight is not FP8 + Legacy FP8 linear function for backward compatibility. + Uses QuantizedTensor subclass for dispatch. """ dtype = self.weight.dtype if dtype not in [torch.float8_e4m3fn]: @@ -372,10 +361,8 @@ def fp8_linear(self, input): input_dtype = input.dtype if len(input.shape) == 3: - # Get weight and bias using standard casting w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) - # Get scales (same as before) scale_weight = self.scale_weight scale_input = self.scale_input if scale_weight is None: @@ -388,14 +375,13 @@ def fp8_linear(self, input): else: scale_input = scale_input.to(input.device) - # Wrap weight in QuantizedTensorFP8 - this enables unified dispatch - quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype) - quantized_input = QuantizedTensorFP8.quantize(input.reshape(-1, input_shape[2]), scale_input, fp8_dtype=dtype) - # Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py! - # This is the key unification: all FP8 computation goes through one path + # Wrap weight in QuantizedTensor - this enables unified dispatch + # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! + layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} + quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) + quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, fp8_dtype=dtype) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) - # Reshape output if tensor_2d: return o.reshape(input_shape[0], -1) return o.reshape((-1, input_shape[1], self.weight.shape[0])) @@ -492,183 +478,117 @@ def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) -# Import quantization operations from separate module -from .quant_ops import QuantizedTensorFP8 - - # ============================================================================== # Mixed Precision Operations # ============================================================================== +from .quant_ops import QuantizedTensor, TensorCoreFP8Layout + +QUANT_FORMAT_MIXINS = { + "float8_e4m3fn": { + "dtype": torch.float8_e4m3fn, + "layout_type": TensorCoreFP8Layout, + "parameters": { + "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), + "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), + } + } +} class MixedPrecisionOps(disable_weight_init): - """ - Operations class supporting per-layer quantization (mixed precision). - - This class enables different layers to use different quantization formats - within the same model (e.g., some layers FP8, others BF16). - - Layer-specific quantization is configured via _layer_quant_config class variable, - which is set by pick_operations() when a model has mixed precision. - """ - - _layer_quant_config = {} # Class variable set by pick_operations() - - class Linear(disable_weight_init.Linear): - """Linear layer with optional per-layer quantization using tensor subclasses""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.quant_format = None - self.quant_scale = None - self._quantization_initialized = False - + _layer_quant_config = {} + _compute_dtype = torch.bfloat16 + + class Linear(torch.nn.Module, CastWeightBiasOp): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__() + + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + # self.factory_kwargs = {"device": device, "dtype": dtype} + + self.in_features = in_features + self.out_features = out_features + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) + + self.tensor_class = None + def reset_parameters(self): - # Don't allocate weights - return None like disable_weight_init return None def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - """ - Called by PyTorch during load_state_dict. - Load weight and wrap in QuantizedTensorFP8 if this layer is quantized. - """ - # Call parent to load weight and bias first - super()._load_from_state_dict( - state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs - ) - - # After weight is loaded, wrap it if this layer is quantized - if not self._quantization_initialized: - # Normalize layer name from prefix - layer_name = prefix.rstrip('.') - - # Strip known model prefixes - for model_prefix in ["model.diffusion_model.", "model.model.", "net."]: - if layer_name.startswith(model_prefix): - layer_name = layer_name[len(model_prefix):] - break - - # Check if this layer has quantization config - if layer_name in MixedPrecisionOps._layer_quant_config: - config = MixedPrecisionOps._layer_quant_config[layer_name] - self.quant_format = config.get("format", "fp8_e4m3fn") - - # Load scale parameter - scale_key = f"{prefix}scale_weight" - if scale_key in state_dict: - self.quant_scale = state_dict[scale_key] - - # Wrap weight in QuantizedTensorFP8 - if self.weight is not None and self.weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - try: - # Determine original dtype (default to bfloat16) - orig_dtype = torch.bfloat16 - - # Wrap weight in quantized tensor subclass - quantized_weight = QuantizedTensorFP8( - self.weight.data, - self.quant_scale, - orig_dtype=orig_dtype - ) - - # Replace weight parameter with wrapped version - self.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) - - logging.debug(f"Wrapped layer {layer_name} weight in QuantizedTensorFP8 (format: {self.quant_format})") - except Exception as e: - logging.warning(f"Failed to wrap layer {layer_name} in QuantizedTensorFP8: {e}") - self.quant_format = None - self.quant_scale = None - else: - logging.debug(f"Layer {layer_name} has scale but weight dtype is not FP8, skipping quantization") - self.quant_format = None - self.quant_scale = None - else: - logging.debug(f"Layer {layer_name} has quant config but no scale_weight in state_dict") - self.quant_format = None - - self._quantization_initialized = True - - def _save_to_state_dict(self, destination, prefix, keep_vars): - """Save layer parameters including quantization scale""" - # First unwrap the weight if it's quantized - if isinstance(self.weight, torch.nn.Parameter) and isinstance(self.weight.data, QuantizedTensorFP8): - # Temporarily unwrap to save the raw FP8 data - quantized_tensor = self.weight.data - raw_fp8_data = quantized_tensor._raw_data - original_weight = self.weight - self.weight = torch.nn.Parameter(raw_fp8_data, requires_grad=False) - - # Call parent to save unwrapped weight - super()._save_to_state_dict(destination, prefix, keep_vars) + + device = self.factory_kwargs["device"] + layer_name = prefix.rstrip('.') + weight_key = f"{prefix}weight" + weight = state_dict.pop(weight_key, None) + if weight is None: + raise ValueError(f"Missing weight for layer {layer_name}") + + if layer_name not in MixedPrecisionOps._layer_quant_config: + self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) + else: + quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) + if quant_format is None: + raise ValueError(f"Unknown quantization format for layer {layer_name}") - # Restore the wrapped weight - self.weight = original_weight + mixin = QUANT_FORMAT_MIXINS[quant_format] + self.layout_type = mixin["layout_type"] - # Save the scale parameter - if self.quant_scale is not None: - destination[f"{prefix}scale_weight"] = self.quant_scale if keep_vars else self.quant_scale.detach() - else: - # Standard path for non-quantized weights - super()._save_to_state_dict(destination, prefix, keep_vars) - + layout_params = { + 'scale': state_dict.pop(f"{prefix}weight_scale", None), + 'orig_dtype': MixedPrecisionOps._compute_dtype + } + self.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params), + requires_grad=False + ) + + for param_name, param_value in mixin["parameters"].items(): + _v = state_dict.pop(f"{prefix}{param_name}", None) + if _v is None: + continue + setattr(self, param_name, _v.to(device=device)) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + def _forward(self, input, weight, bias): + return torch.nn.functional.linear(input, weight, bias) + def forward_comfy_cast_weights(self, input): - """ - Forward pass - tensor subclass handles dispatch automatically! - __torch_dispatch__ will route to registered handlers based on tensor types. - """ weight, bias = cast_bias_weight(self, input) - - # Call F.linear - if weight is QuantizedTensorFP8, __torch_dispatch__ handles it! - return torch.nn.functional.linear(input, weight, bias) - - def forward(self, *args, **kwargs): - """Main forward pass""" + self._forward(input, weight, bias) + + def forward(self, input, *args, **kwargs): run_every_op() - # Same logic as disable_weight_init.Linear + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: - return self.forward_comfy_cast_weights(*args, **kwargs) - else: - return super().forward(*args, **kwargs) + return self.forward_comfy_cast_weights(input, *args, **kwargs) + if (getattr(self, 'layout_type', None) is not None and + getattr(self, 'input_scale', None) is not None and + not isinstance(input, QuantizedTensor)): + input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype) + return self._forward(input, self.weight, self.bias) - @classmethod - def conv_nd(s, dims, *args, **kwargs): - """Create Conv layer (same as disable_weight_init)""" - if dims == 2: - return s.Conv2d(*args, **kwargs) - elif dims == 3: - return s.Conv3d(*args, **kwargs) - else: - raise ValueError(f"unsupported dimensions: {dims}") - def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): - """ - Select appropriate operations class for model. - - NEW: If model_config.layer_quant_config exists, returns MixedPrecisionOps (Phase 3). - LEGACY: All other paths unchanged for backward compatibility. - - Args: - weight_dtype: Weight storage dtype - compute_dtype: Computation dtype - load_device: Device for loading - disable_fast_fp8: Disable fast FP8 paths - fp8_optimizations: Enable FP8 optimizations - scaled_fp8: Legacy FP8 dtype marker - model_config: Model config object (optional, for mixed precision support) - - Returns: - Operations class (e.g., MixedPrecisionOps, fp8_ops, disable_weight_init) - """ - # NEW: Check for mixed precision + # If model_config.layer_quant_config exists, use new MixedPrecisionOps. if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config + # MixedPrecisionOps._compute_dtype = compute_dtype # TODO + MixedPrecisionOps._compute_dtype = torch.bfloat16 logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") return MixedPrecisionOps - # LEGACY paths (unchanged) fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 8e3bacbaf8af..3802da8524e7 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -1,42 +1,79 @@ import torch import logging +from typing import Tuple, Dict -# ============================================================================== -# Global Operation Registry -# ============================================================================== +_LAYOUT_REGISTRY = {} +_GENERIC_UTILS = {} -# Global operation registry: torch operation → handler function -_QUANT_OP_REGISTRY = {} -def register_quant_op(torch_op): +def register_layout_op(torch_op, layout_type): """ - Decorator to register an operation handler. + Decorator to register a layout-specific operation handler. + Args: + torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default) + layout_type: Layout class (e.g., TensorCoreFP8Layout) + Example: + @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) + def fp8_linear(func, args, kwargs): + # FP8-specific linear implementation + ... + """ + def decorator(handler_func): + if torch_op not in _LAYOUT_REGISTRY: + _LAYOUT_REGISTRY[torch_op] = {} + _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func + return handler_func + return decorator + + +def register_generic_util(torch_op): + """ + Decorator to register a generic utility that works for all layouts. + Args: + torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default) Example: - @register_quant_op(torch.ops.aten.linear.default) - def handle_linear_fp8(func, args, kwargs): - # Implementation + @register_generic_util(torch.ops.aten.detach.default) + def generic_detach(func, args, kwargs): + # Works for any layout ... """ def decorator(handler_func): - _QUANT_OP_REGISTRY[torch_op] = handler_func + _GENERIC_UTILS[torch_op] = handler_func return handler_func return decorator -def get_quant_handler(torch_op): - """Get registered handler for an operation""" - return _QUANT_OP_REGISTRY.get(torch_op) +def _get_layout_from_args(args): + for arg in args: + if isinstance(arg, QuantizedTensor): + return arg._layout_type + elif isinstance(arg, (list, tuple)): + for item in arg: + if isinstance(item, QuantizedTensor): + return item._layout_type + return None -def list_registered_ops(): - """List all registered quantized operations""" - return list(_QUANT_OP_REGISTRY.keys()) +def _move_layout_params_to_device(params, device): + new_params = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + new_params[k] = v.to(device=device) + else: + new_params[k] = v + return new_params -# ============================================================================== -# comfy_kitchen Integration -# ============================================================================== +def _copy_layout_params(params): + new_params = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + new_params[k] = v.clone() + else: + new_params[k] = v + return new_params + try: import comfy_kitchen as ck @@ -53,106 +90,144 @@ def list_registered_ops(): logging.warning(f"comfy_kitchen import failed: {e} - using PyTorch fallbacks") -# ============================================================================== -# Quantized Tensor Subclass -# ============================================================================== +class QuantizedLayout: + """ + Base class for quantization layouts. + + A layout encapsulates the format-specific logic for quantization/dequantization + and provides a uniform interface for extracting raw tensors needed for computation. + + New quantization formats should subclass this and implement the required methods. + """ + @classmethod + def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]: + raise NotImplementedError(f"{cls.__name__} must implement quantize()") -class QuantizedTensorFP8(torch.Tensor): + @staticmethod + def dequantize(qdata, **layout_params) -> torch.Tensor: + raise NotImplementedError(f"TensorLayout must implement dequantize()") + + @classmethod + def get_plain_tensors(cls, qtensor) -> torch.Tensor: + raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()") + + +class QuantizedTensor(torch.Tensor): """ - Tensor subclass for FP8 quantized data. - Automatically handles operations via __torch_dispatch__. + Universal quantized tensor that works with any layout. + + This tensor subclass uses a pluggable layout system to support multiple + quantization formats (FP8, INT4, INT8, etc.) without code duplication. + + The layout_type determines format-specific behavior, while common operations + (detach, clone, to) are handled generically. + + Attributes: + _qdata: The quantized tensor data + _layout_type: Layout class (e.g., TensorCoreFP8Layout) + _layout_params: Dict with layout-specific params (scale, zero_point, etc.) """ @staticmethod - def __new__(cls, tensor, scale, orig_dtype=torch.bfloat16): + def __new__(cls, qdata, layout_type, layout_params): """ - Create a quantized FP8 tensor. + Create a quantized tensor. Args: - tensor: The FP8 tensor data (torch.float8_e4m3fn or e5m2) - scale: Scale factor for dequantization (scalar tensor) - orig_dtype: Original dtype before quantization + qdata: The quantized data tensor + layout_type: Layout class (subclass of QuantizedLayout) + layout_params: Dict with layout-specific parameters """ - return torch.Tensor._make_subclass(cls, tensor, require_grad=False) + return torch.Tensor._make_subclass(cls, qdata, require_grad=False) - def __init__(self, tensor, scale, orig_dtype=torch.bfloat16): - self._scale = scale - self._orig_dtype = orig_dtype - # Store a reference to prevent infinite recursion in dequantize - self._raw_data = tensor.contiguous() + def __init__(self, qdata, layout_type, layout_params): + self._qdata = qdata.contiguous() + self._layout_type = layout_type + self._layout_params = layout_params def __repr__(self): - return (f"QuantizedTensorFP8(shape={self.shape}, " - f"scale={self._scale:.4f}, dtype={self._orig_dtype})") + layout_name = self._layout_type.__name__ + param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) + return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" - @classmethod - def quantize(cls, tensor, scale, fp8_dtype=torch.float8_e4m3fn): - orig_dtype = tensor.dtype + @property + def layout_type(self): + return self._layout_type + + def __tensor_flatten__(self): + """ + Tensor flattening protocol for proper device movement. + """ + inner_tensors = ["_q_data"] + ctx = { + "layout_type": self._layout_type, + } - if not isinstance(scale, torch.Tensor): - scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32) + tensor_params = {} + non_tensor_params = {} + for k, v in self._layout_params.items(): + if isinstance(v, torch.Tensor): + tensor_params[k] = v + else: + non_tensor_params[k] = v - tensor_fp8 = None - if _CK_AVAILABLE: - try: - tensor_fp8 = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype) - except Exception as e: - logging.debug(f"comfy_kitchen quantization failed, using PyTorch: {e}") + ctx["tensor_param_keys"] = list(tensor_params.keys()) + ctx["non_tensor_params"] = non_tensor_params - if tensor_fp8 is None: - lp_amax = torch.finfo(fp8_dtype).max - tensor_scaled = tensor.float() / scale - torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) - tensor_fp8 = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format) + for k, v in tensor_params.items(): + attr_name = f"_layout_param_{k}" + object.__setattr__(self, attr_name, v) + inner_tensors.append(attr_name) + + return inner_tensors, ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): + """ + Tensor unflattening protocol for proper device movement. + Reconstructs the QuantizedTensor after device movement. + """ + layout_type = ctx["layout_type"] + layout_params = dict(ctx["non_tensor_params"]) + + for key in ctx["tensor_param_keys"]: + attr_name = f"_layout_param_{key}" + layout_params[key] = inner_tensors[attr_name] - return cls(tensor_fp8, scale, orig_dtype=orig_dtype) + return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params) @classmethod - def quantize_dynamic(cls, tensor, strategy="amax", fp8_dtype=torch.float8_e4m3fn): - if strategy == "amax": - scale = torch.amax(tensor) / torch.finfo(fp8_dtype).max - scale = scale.to(tensor.device, dtype=torch.float32) - else: - raise ValueError(f"Unknown quantization strategy: {strategy}. " - f"Supported: 'amax'") - - return cls.quantize(tensor, scale, fp8_dtype=fp8_dtype) + def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': + qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs) + return cls(qdata, layout_type, layout_params) + + def dequantize(self) -> torch.Tensor: + return self._layout_type.dequantize(self._qdata, **self._layout_params) @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): kwargs = kwargs or {} - # Special case: skip dispatch for internal tensor operations - # that are used for unwrapping (to avoid recursion) - if func in [torch.ops.aten._to_copy.default, torch.ops.aten.detach.default]: - # For these ops, use the raw data to avoid recursion, but return QuantizedTensorFP8 for detach - if func == torch.ops.aten.detach.default and isinstance(args[0], QuantizedTensorFP8): - # Special handling for detach - return a new QuantizedTensorFP8 - qt = args[0] - detached_data = qt._raw_data.detach() - return QuantizedTensorFP8(detached_data, qt._scale, qt._orig_dtype) - - # For other ops, just unwrap - def unwrap(arg): - if isinstance(arg, QuantizedTensorFP8): - return arg._raw_data - return arg - new_args = tuple(unwrap(a) if not isinstance(a, (list, tuple, dict)) else a for a in args) - return func(*new_args, **kwargs) + # Step 1: Check generic utilities first (detach, clone, to, etc.) + if func in _GENERIC_UTILS: + return _GENERIC_UTILS[func](func, args, kwargs) - # Look up registered handler for this operation - handler = _QUANT_OP_REGISTRY.get(func) - if handler: - return handler(func, args, kwargs) + # Step 2: Check layout-specific handlers (linear, matmul, etc.) + layout_type = _get_layout_from_args(args) + if layout_type and func in _LAYOUT_REGISTRY: + handler = _LAYOUT_REGISTRY[func].get(layout_type) + if handler: + return handler(func, args, kwargs) - # No handler - dequantize and use standard path + # Step 3: Fallback to dequantization + if isinstance(args[0] if args else None, QuantizedTensor): + logging.warning(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") return cls._dequant_and_fallback(func, args, kwargs) @classmethod def _dequant_and_fallback(cls, func, args, kwargs): - """Fallback: dequantize all quantized tensors""" def dequant_arg(arg): - if isinstance(arg, QuantizedTensorFP8): + if isinstance(arg, QuantizedTensor): return arg.dequantize() elif isinstance(arg, (list, tuple)): return type(arg)(dequant_arg(a) for a in arg) @@ -161,75 +236,220 @@ def dequant_arg(arg): new_args = dequant_arg(args) new_kwargs = dequant_arg(kwargs) return func(*new_args, **new_kwargs) + + +# ============================================================================== +# Generic Utilities (Layout-Agnostic Operations) +# ============================================================================== + +def _create_transformed_qtensor(qt, transform_fn): + new_data = transform_fn(qt._qdata) + new_params = _copy_layout_params(qt._layout_params) + return QuantizedTensor(new_data, qt._layout_type, new_params) + + +def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"): + if target_dtype is not None and target_dtype != qt.dtype: + logging.warning( + f"QuantizedTensor: dtype conversion requested to {target_dtype}, " + f"but not supported for quantized tensors. Ignoring dtype." + ) - def dequantize(self) -> torch.Tensor: - plain_tensor = torch.ops.aten._to_copy.default(self._raw_data, dtype=self._orig_dtype) - return plain_tensor * self._scale + if target_layout is not None and target_layout != torch.strided: + logging.warning( + f"QuantizedTensor: layout change requested to {target_layout}, " + f"but not supported. Ignoring layout." + ) + + # Handle device transfer + current_device = qt._qdata.device + if target_device is not None: + # Normalize device for comparison + if isinstance(target_device, str): + target_device = torch.device(target_device) + if isinstance(current_device, str): + current_device = torch.device(current_device) + + if target_device != current_device: + logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") + new_q_data = qt._qdata.to(device=target_device) + new_params = _move_layout_params_to_device(qt._layout_params, target_device) + new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) + logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") + return new_qt - def detach(self): - """Detach returns a new QuantizedTensorFP8 (required for Parameter)""" - detached_data = self._raw_data.detach() - return QuantizedTensorFP8(detached_data, self._scale, self._orig_dtype) + logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original") + return qt + + +@register_generic_util(torch.ops.aten.detach.default) +def generic_detach(func, args, kwargs): + """Detach operation - creates a detached copy of the quantized tensor.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _create_transformed_qtensor(qt, lambda x: x.detach()) + return func(*args, **kwargs) + +@register_generic_util(torch.ops.aten.clone.default) +def generic_clone(func, args, kwargs): + """Clone operation - creates a deep copy of the quantized tensor.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _create_transformed_qtensor(qt, lambda x: x.clone()) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten._to_copy.default) +def generic_to_copy(func, args, kwargs): + """Device/dtype transfer operation - handles .to(device) calls.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _handle_device_transfer( + qt, + target_device=kwargs.get('device', None), + target_dtype=kwargs.get('dtype', None), + op_name="_to_copy" + ) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.to.dtype_layout) +def generic_to_dtype_layout(func, args, kwargs): + """Handle .to(device) calls using the dtype_layout variant.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _handle_device_transfer( + qt, + target_device=kwargs.get('device', None), + target_dtype=kwargs.get('dtype', None), + target_layout=kwargs.get('layout', None), + op_name="to" + ) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.copy_.default) +def generic_copy_(func, args, kwargs): + qt_dest = args[0] + src = args[1] + + if isinstance(qt_dest, QuantizedTensor): + if isinstance(src, QuantizedTensor): + # Copy from another quantized tensor + qt_dest._qdata.copy_(src._qdata) + qt_dest._layout_type = src._layout_type + qt_dest._layout_params = _copy_layout_params(src._layout_params) + else: + # Copy from regular tensor - just copy raw data + qt_dest._qdata.copy_(src) + return qt_dest + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) +def generic_has_compatible_shallow_copy_type(func, args, kwargs): + return True # ============================================================================== -# Operation Handlers for Quantized Tensors +# FP8 Layout + Operation Handlers # ============================================================================== - -@register_quant_op(torch.ops.aten.linear.default) -def handle_linear_fp8(func, args, kwargs): +class TensorCoreFP8Layout(QuantizedLayout): """ - Handle F.linear() with quantized inputs. - - Supports: - - QuantizedTensorFP8 input + QuantizedTensorFP8 weight - - QuantizedTensorFP8 input + regular weight - - Regular input + QuantizedTensorFP8 weight + Storage format: + - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2) + - scale: Scalar tensor (float32) for dequantization + - orig_dtype: Original dtype before quantization (for casting back) """ + @classmethod + def quantize(cls, tensor, scale=None, fp8_dtype=torch.float8_e4m3fn): + orig_dtype = tensor.dtype + + if scale is None: + scale = torch.amax(tensor.abs()) / torch.finfo(fp8_dtype).max + + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) + + if _CK_AVAILABLE and tensor.device.type == "cuda": + qdata = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype) + else: + lp_amax = torch.finfo(fp8_dtype).max + tensor_scaled = tensor.float() / scale + torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) + qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format) + + layout_params = { + 'scale': scale, + 'orig_dtype': orig_dtype + } + return qdata, layout_params + + @staticmethod + def dequantize(qdata, scale, orig_dtype, **kwargs): + plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype) + return plain_tensor * scale + + @classmethod + def get_plain_tensors(cls, qtensor): + return qtensor._qdata, qtensor._layout_params['scale'] + + +@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) +def fp8_linear(func, args, kwargs): input_tensor = args[0] weight = args[1] bias = args[2] if len(args) > 2 else None - out_dtype = kwargs.get("out_dtype", input_tensor._orig_dtype) - - # Case 1: Both input and weight are FP8 - if isinstance(input_tensor, QuantizedTensorFP8) and isinstance(weight, QuantizedTensorFP8): - # Get plain tensors to avoid dispatch recursion - plain_input = input_tensor._raw_data - plain_weight = weight._raw_data - weight_t = plain_weight.t() # Keep as column-major for cuBLASLt + + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) + + out_dtype = kwargs.get("out_dtype") + if out_dtype is None: + out_dtype = input_tensor._layout_params['orig_dtype'] + weight_t = plain_weight.t() + + tensor_2d = False + if len(plain_input.shape) == 2: + tensor_2d = True + plain_input = plain_input.unsqueeze(1) + + input_shape = plain_input.shape + if len(input_shape) != 3: + return None + try: output = torch._scaled_mm( - plain_input, + plain_input.reshape(-1, input_shape[2]), weight_t, bias=bias, - scale_a=input_tensor._scale, - scale_b=weight._scale, + scale_a=scale_a, + scale_b=scale_b, out_dtype=out_dtype, ) - if isinstance(output, tuple): - output = output[0] - + if not tensor_2d: + output = output.reshape((-1, input_shape[1], weight.shape[0])) + if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - output_scale = input_tensor._scale * weight._scale - return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype) # TODO is this correct? Can't cuBLAS return it + output_scale = scale_a * scale_b + output_params = { + 'scale': output_scale, + 'orig_dtype': input_tensor._layout_params['orig_dtype'] + } + return QuantizedTensor(output, TensorCoreFP8Layout, output_params) else: return output + except Exception as e: - logging.debug(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") + raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") - # Case 2: Only weight is quantized - if isinstance(weight, QuantizedTensorFP8): - weight_dq = weight.dequantize() - input_dq = input_tensor.dequantize() if isinstance(input_tensor, QuantizedTensorFP8) else input_tensor - return torch.nn.functional.linear(input_dq, weight_dq, bias) - - # Case 3: Only input is quantized - elif isinstance(input_tensor, QuantizedTensorFP8): - input_dq = input_tensor.dequantize() - return torch.nn.functional.linear(input_dq, weight, bias) - - # Case 4: Neither is quantized (shouldn't happen, but handle it) - else: - return torch.nn.functional.linear(input_tensor, weight, bias) + # Case 2: DQ Fallback + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + if isinstance(input_tensor, QuantizedTensor): + input_tensor = input_tensor.dequantize() + return torch.nn.functional.linear(input_tensor, weight, bias) diff --git a/comfy/sd.py b/comfy/sd.py index 28bee248dae1..b965e98427d9 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1262,7 +1262,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return (model_patcher, clip, vae, clipvision) -def load_diffusion_model_state_dict(sd, model_options={}): +def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): """ Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. @@ -1296,7 +1296,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): weight_dtype = comfy.utils.weight_dtype(sd) load_device = model_management.get_torch_device() - model_config = model_detection.model_config_from_unet(sd, "") + model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata) if model_config is not None: new_sd = sd @@ -1346,8 +1346,8 @@ def load_diffusion_model_state_dict(sd, model_options={}): def load_diffusion_model(unet_path, model_options={}): - sd = comfy.utils.load_torch_file(unet_path) - model = load_diffusion_model_state_dict(sd, model_options=model_options) + sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True) + model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata) if model is None: logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) From 0d201540ca6e1827c7d7267895aa4e6324910a44 Mon Sep 17 00:00:00 2001 From: lspindler Date: Mon, 27 Oct 2025 07:55:44 +0100 Subject: [PATCH 43/49] Remove CK reference and ensure correct compute dtype --- comfy/model_detection.py | 12 ++---------- comfy/ops.py | 4 +--- comfy/quant_ops.py | 26 ++++---------------------- comfy/sd.py | 5 ++++- 4 files changed, 11 insertions(+), 36 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index feab164c6b86..378250e04494 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -7,8 +7,7 @@ import torch -def detect_layer_quantization(state_dict, metadata, prefix="model.diffusion_model."): - # 1. Check for per-layer config in metadata +def detect_layer_quantization(metadata): quant_key = "_quantization_metadata" if metadata is not None and quant_key in metadata: quant_metadata = metadata.pop(quant_key) @@ -18,13 +17,6 @@ def detect_layer_quantization(state_dict, metadata, prefix="model.diffusion_mode return quant_metadata["layers"] else: raise ValueError(f"Invalid quantization metadata format") - - # 2. Check for legacy scaled_fp8 marker - scaled_fp8_key = f"{prefix}scaled_fp8" - if scaled_fp8_key in state_dict: - logging.debug("Detected legacy scaled_fp8 format, using legacy code path") - return None - return None @@ -724,7 +716,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal model_config.optimizations["fp8"] = True # Detect per-layer quantization (mixed precision) - layer_quant_config = detect_layer_quantization(state_dict, metadata, unet_key_prefix) + layer_quant_config = detect_layer_quantization(metadata) if layer_quant_config: model_config.layer_quant_config = layer_quant_config logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized") diff --git a/comfy/ops.py b/comfy/ops.py index 4e24b25de2fe..b46e7553de8b 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -581,11 +581,9 @@ def forward(self, input, *args, **kwargs): def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): - # If model_config.layer_quant_config exists, use new MixedPrecisionOps. if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config - # MixedPrecisionOps._compute_dtype = compute_dtype # TODO - MixedPrecisionOps._compute_dtype = torch.bfloat16 + MixedPrecisionOps._compute_dtype = compute_dtype logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") return MixedPrecisionOps diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 3802da8524e7..8d7f6480a31c 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -75,21 +75,6 @@ def _copy_layout_params(params): return new_params -try: - import comfy_kitchen as ck - ck.disable_backend("cutile") - _CK_AVAILABLE = True - logging.info("comfy_kitchen available for optimized quantization kernels") -except ImportError: - ck = None - _CK_AVAILABLE = False - logging.info("comfy_kitchen not available - using PyTorch fallbacks") -except Exception as e: - ck = None - _CK_AVAILABLE = False - logging.warning(f"comfy_kitchen import failed: {e} - using PyTorch fallbacks") - - class QuantizedLayout: """ Base class for quantization layouts. @@ -372,13 +357,10 @@ def quantize(cls, tensor, scale=None, fp8_dtype=torch.float8_e4m3fn): scale = torch.tensor(scale) scale = scale.to(device=tensor.device, dtype=torch.float32) - if _CK_AVAILABLE and tensor.device.type == "cuda": - qdata = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype) - else: - lp_amax = torch.finfo(fp8_dtype).max - tensor_scaled = tensor.float() / scale - torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) - qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format) + lp_amax = torch.finfo(fp8_dtype).max + tensor_scaled = tensor.float() / scale + torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) + qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format) layout_params = { 'scale': scale, diff --git a/comfy/sd.py b/comfy/sd.py index b965e98427d9..6411bb27d62e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1330,7 +1330,10 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): else: unet_dtype = dtype - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + if hasattr(model_config, "layer_quant_config"): + manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) + else: + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations) if model_options.get("fp8_optimizations", False): From 77d307049f22816278574d8003df6f7d891894df Mon Sep 17 00:00:00 2001 From: lspindler Date: Mon, 27 Oct 2025 08:41:23 +0100 Subject: [PATCH 44/49] Update unit tests --- comfy/ops.py | 2 +- comfy/quant_ops.py | 4 +- .../test_mixed_precision.py | 147 +++---- tests-unit/comfy_quant/test_quant_registry.py | 183 ++++++++ tests-unit/comfy_test/test_quant_detection.py | 262 ------------ tests-unit/comfy_test/test_quant_registry.py | 399 ------------------ 6 files changed, 235 insertions(+), 762 deletions(-) rename tests-unit/{comfy_test => comfy_quant}/test_mixed_precision.py (60%) create mode 100644 tests-unit/comfy_quant/test_quant_registry.py delete mode 100644 tests-unit/comfy_test/test_quant_detection.py delete mode 100644 tests-unit/comfy_test/test_quant_registry.py diff --git a/comfy/ops.py b/comfy/ops.py index b46e7553de8b..911228b51dda 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -566,7 +566,7 @@ def _forward(self, input, weight, bias): def forward_comfy_cast_weights(self, input): weight, bias = cast_bias_weight(self, input) - self._forward(input, weight, bias) + return self._forward(input, weight, bias) def forward(self, input, *args, **kwargs): run_every_op() diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 8d7f6480a31c..96d2fa03fdbd 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -143,7 +143,7 @@ def __tensor_flatten__(self): """ Tensor flattening protocol for proper device movement. """ - inner_tensors = ["_q_data"] + inner_tensors = ["_qdata"] ctx = { "layout_type": self._layout_type, } @@ -206,7 +206,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # Step 3: Fallback to dequantization if isinstance(args[0] if args else None, QuantizedTensor): - logging.warning(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") + logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") return cls._dequant_and_fallback(func, args, kwargs) @classmethod diff --git a/tests-unit/comfy_test/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py similarity index 60% rename from tests-unit/comfy_test/test_mixed_precision.py rename to tests-unit/comfy_quant/test_mixed_precision.py index cbfa2866da4d..e3455276063d 100644 --- a/tests-unit/comfy_test/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -1,8 +1,3 @@ -""" -End-to-end tests for mixed precision quantization. -Tests Phase 3: Mixed Precision Operations -""" - import unittest import torch import sys @@ -12,10 +7,10 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) from comfy import ops +from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout class SimpleModel(torch.nn.Module): - """Simple model for testing mixed precision""" def __init__(self, operations=ops.disable_weight_init): super().__init__() self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16) @@ -32,8 +27,7 @@ def forward(self, x): class TestMixedPrecisionOps(unittest.TestCase): - """Test MixedPrecisionOps end-to-end""" - + def test_all_layers_standard(self): """Test that model with no quantization works normally""" # Configure no quantization @@ -67,48 +61,54 @@ def test_mixed_precision_load(self): # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard layer_quant_config = { "layer1": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": False} # Disable for CPU testing + "format": "float8_e4m3fn", + "params": {} }, "layer3": { - "format": "fp8_e5m2_scaled", - "params": {"use_fp8_matmul": False} + "format": "float8_e4m3fn", + "params": {} } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create state dict with mixed precision fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) - fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e5m2) + fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn) state_dict = { # Layer 1: FP8 E4M3FN "layer1.weight": fp8_weight1, "layer1.bias": torch.randn(20, dtype=torch.bfloat16), - "layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), # Layer 2: Standard BF16 "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), "layer2.bias": torch.randn(30, dtype=torch.bfloat16), - # Layer 3: FP8 E5M2 + # Layer 3: FP8 E4M3FN "layer3.weight": fp8_weight3, "layer3.bias": torch.randn(40, dtype=torch.bfloat16), - "layer3.scale_weight": torch.tensor(1.5, dtype=torch.float32), + "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32), } - # Create model and load state dict + # Create model and load state dict (strict=False because custom loading pops keys) model = SimpleModel(operations=ops.MixedPrecisionOps) - model.load_state_dict(state_dict) + model.load_state_dict(state_dict, strict=False) + + # Verify weights are wrapped in QuantizedTensor + self.assertIsInstance(model.layer1.weight, QuantizedTensor) + self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout) - # Verify handlers are set up correctly - self.assertIsNotNone(model.layer1.quant_handler) - self.assertIsNone(model.layer2.quant_handler) # No quantization - self.assertIsNotNone(model.layer3.quant_handler) + # Layer 2 should NOT be quantized + self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) + + # Layer 3 should be quantized + self.assertIsInstance(model.layer3.weight, QuantizedTensor) + self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout) # Verify scales were loaded - self.assertEqual(model.layer1.scale_weight.item(), 2.0) - self.assertEqual(model.layer3.scale_weight.item(), 1.5) + self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) + self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5) # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) @@ -116,13 +116,13 @@ def test_mixed_precision_load(self): self.assertEqual(output.shape, (5, 40)) - def test_state_dict_round_trip(self): - """Test saving and loading state dict preserves quantization""" + def test_state_dict_quantized_preserved(self): + """Test that quantized weights are preserved in state_dict()""" # Configure mixed precision layer_quant_config = { "layer1": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": False} + "format": "float8_e4m3fn", + "params": {} } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config @@ -132,45 +132,35 @@ def test_state_dict_round_trip(self): state_dict1 = { "layer1.weight": fp8_weight, "layer1.bias": torch.randn(20, dtype=torch.bfloat16), - "layer1.scale_weight": torch.tensor(3.0, dtype=torch.float32), + "layer1.weight_scale": torch.tensor(3.0, dtype=torch.float32), "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), "layer2.bias": torch.randn(30, dtype=torch.bfloat16), "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model1 = SimpleModel(operations=ops.MixedPrecisionOps) - model1.load_state_dict(state_dict1) + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict1, strict=False) # Save state dict - state_dict2 = model1.state_dict() - - # Verify scale_weight is saved - self.assertIn("layer1.scale_weight", state_dict2) - self.assertEqual(state_dict2["layer1.scale_weight"].item(), 3.0) - - # Load into new model - model2 = SimpleModel(operations=ops.MixedPrecisionOps) - model2.load_state_dict(state_dict2) - - # Verify handler is set up - self.assertIsNotNone(model2.layer1.quant_handler) - self.assertEqual(model2.layer1.scale_weight.item(), 3.0) + state_dict2 = model.state_dict() - # Verify forward passes match - input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) - output1 = model1(input_tensor) - output2 = model2(input_tensor) + # Verify layer1.weight is a QuantizedTensor with scale preserved + self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) + self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) + self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout) - torch.testing.assert_close(output1, output2, rtol=1e-3, atol=1e-3) + # Verify non-quantized layers are standard tensors + self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) + self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor) def test_weight_function_compatibility(self): """Test that weight_function (LoRA) works with quantized layers""" # Configure FP8 quantization layer_quant_config = { "layer1": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": False} + "format": "float8_e4m3fn", + "params": {} } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config @@ -180,7 +170,7 @@ def test_weight_function_compatibility(self): state_dict = { "layer1.weight": fp8_weight, "layer1.bias": torch.randn(20, dtype=torch.bfloat16), - "layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), "layer2.bias": torch.randn(30, dtype=torch.bfloat16), "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), @@ -188,25 +178,24 @@ def test_weight_function_compatibility(self): } model = SimpleModel(operations=ops.MixedPrecisionOps) - model.load_state_dict(state_dict) + model.load_state_dict(state_dict, strict=False) # Add a weight function (simulating LoRA) - # LoRA delta must match weight shape (20, 10) + # This should trigger dequantization during forward pass def apply_lora(weight): - # Generate LoRA delta matching weight shape lora_delta = torch.randn_like(weight) * 0.01 return weight + lora_delta model.layer1.weight_function.append(apply_lora) - # Forward pass should work with LoRA + # Forward pass should work with LoRA (triggers weight_function path) input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) output = model(input_tensor) self.assertEqual(output.shape, (5, 40)) def test_error_handling_unknown_format(self): - """Test that unknown formats fall back gracefully""" + """Test that unknown formats raise error""" # Configure with unknown format layer_quant_config = { "layer1": { @@ -226,48 +215,10 @@ def test_error_handling_unknown_format(self): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - # Load should not crash, just log warning + # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS model = SimpleModel(operations=ops.MixedPrecisionOps) - model.load_state_dict(state_dict) - - # Handler should be None (fallback to standard) - self.assertIsNone(model.layer1.quant_handler) - - # Forward pass should still work - input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) - output = model(input_tensor) - self.assertEqual(output.shape, (5, 40)) - - -class TestPickOperationsWithMixedPrecision(unittest.TestCase): - """Test pick_operations with mixed precision config""" - - def test_pick_operations_with_layer_quant_config(self): - """Test that pick_operations returns MixedPrecisionOps when config present""" - from comfy import supported_models_base - - # Create model config with layer_quant_config - model_config = supported_models_base.BASE({}) - model_config.layer_quant_config = { - "layer1": {"format": "fp8_e4m3fn_scaled", "params": {}} - } - - result = ops.pick_operations(None, None, model_config=model_config) - - self.assertEqual(result, ops.MixedPrecisionOps) - self.assertEqual(ops.MixedPrecisionOps._layer_quant_config, model_config.layer_quant_config) - - def test_pick_operations_without_layer_quant_config(self): - """Test that pick_operations falls back to standard when no config""" - from comfy import supported_models_base - - model_config = supported_models_base.BASE({}) - model_config.layer_quant_config = None - - result = ops.pick_operations(None, None, model_config=model_config) - - self.assertEqual(result, ops.disable_weight_init) - + with self.assertRaises(KeyError): + model.load_state_dict(state_dict, strict=False) if __name__ == "__main__": unittest.main() diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py new file mode 100644 index 000000000000..263581417177 --- /dev/null +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -0,0 +1,183 @@ +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout + + +class TestQuantizedTensor(unittest.TestCase): + """Test the QuantizedTensor subclass with FP8 layout""" + + def test_creation(self): + """Test creating a QuantizedTensor with TensorCoreFP8Layout""" + fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(2.0) + layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} + + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.shape, (256, 128)) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt._layout_params['scale'], scale) + self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) + self.assertEqual(qt._layout_type, TensorCoreFP8Layout) + + def test_dequantize(self): + """Test explicit dequantization""" + + fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(3.0) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + dequantized = qt.dequantize() + + self.assertEqual(dequantized.dtype, torch.float32) + self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) + + def test_from_float(self): + """Test creating QuantizedTensor from float tensor""" + float_tensor = torch.randn(64, 32, dtype=torch.float32) + scale = torch.tensor(1.5) + + qt = QuantizedTensor.from_float( + float_tensor, + TensorCoreFP8Layout, + scale=scale, + fp8_dtype=torch.float8_e4m3fn + ) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt.shape, (64, 32)) + + # Verify dequantization gives approximately original values + dequantized = qt.dequantize() + mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean() + self.assertLess(mean_rel_error, 0.1) + + +class TestGenericUtilities(unittest.TestCase): + """Test generic utility operations""" + + def test_detach(self): + """Test detach operation on quantized tensor""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Detach should return a new QuantizedTensor + qt_detached = qt.detach() + + self.assertIsInstance(qt_detached, QuantizedTensor) + self.assertEqual(qt_detached.shape, qt.shape) + self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout) + + def test_clone(self): + """Test clone operation on quantized tensor""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Clone should return a new QuantizedTensor + qt_cloned = qt.clone() + + self.assertIsInstance(qt_cloned, QuantizedTensor) + self.assertEqual(qt_cloned.shape, qt.shape) + self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout) + + # Verify it's a deep copy + self.assertIsNot(qt_cloned._qdata, qt._qdata) + + def test_to_device(self): + """Test device transfer""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Moving to same device should work (CPU to CPU) + qt_cpu = qt.to('cpu') + + self.assertIsInstance(qt_cpu, QuantizedTensor) + self.assertEqual(qt_cpu.device.type, 'cpu') + self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') + + +class TestTensorCoreFP8Layout(unittest.TestCase): + """Test the TensorCoreFP8Layout implementation""" + + def test_quantize(self): + """Test quantization method""" + float_tensor = torch.randn(32, 64, dtype=torch.float32) + scale = torch.tensor(1.5) + + qdata, layout_params = TensorCoreFP8Layout.quantize( + float_tensor, + scale=scale, + fp8_dtype=torch.float8_e4m3fn + ) + + self.assertEqual(qdata.dtype, torch.float8_e4m3fn) + self.assertEqual(qdata.shape, float_tensor.shape) + self.assertIn('scale', layout_params) + self.assertIn('orig_dtype', layout_params) + self.assertEqual(layout_params['orig_dtype'], torch.float32) + + def test_dequantize(self): + """Test dequantization method""" + float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0 + scale = torch.tensor(1.0) + + qdata, layout_params = TensorCoreFP8Layout.quantize( + float_tensor, + scale=scale, + fp8_dtype=torch.float8_e4m3fn + ) + + dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) + + # Should approximately match original + self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) + + +class TestFallbackMechanism(unittest.TestCase): + """Test fallback for unsupported operations""" + + def test_unsupported_op_dequantizes(self): + """Test that unsupported operations fall back to dequantization""" + # Set seed for reproducibility + torch.manual_seed(42) + + # Create quantized tensor + a_fp32 = torch.randn(10, 20, dtype=torch.float32) + scale = torch.tensor(1.0) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + a_q = QuantizedTensor.from_float( + a_fp32, + TensorCoreFP8Layout, + scale=scale, + fp8_dtype=torch.float8_e4m3fn + ) + + # Call an operation that doesn't have a registered handler + # For example, torch.abs + result = torch.abs(a_q) + + # Should work via fallback (dequantize → abs → return) + self.assertNotIsInstance(result, QuantizedTensor) + expected = torch.abs(a_fp32) + # FP8 introduces quantization error, so use loose tolerance + mean_error = (result - expected).abs().mean() + self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests-unit/comfy_test/test_quant_detection.py b/tests-unit/comfy_test/test_quant_detection.py deleted file mode 100644 index bb952a81b3b2..000000000000 --- a/tests-unit/comfy_test/test_quant_detection.py +++ /dev/null @@ -1,262 +0,0 @@ -""" -Integration tests for quantization detection. -Tests Phase 2: Detection & Integration -""" - -import unittest -import torch -import sys -import os - -# Add comfy to path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) - -from comfy import model_detection - - -class TestNormalizeLayerName(unittest.TestCase): - """Test the normalize_layer_name helper function""" - - def test_strip_prefix_and_suffix(self): - """Test stripping prefix and suffix""" - known_prefixes = ["model.diffusion_model."] - result = model_detection.normalize_layer_name( - "model.diffusion_model.layer1.weight", - known_prefixes - ) - self.assertEqual(result, "layer1") - - def test_strip_multiple_prefixes(self): - """Test with multiple known prefixes""" - known_prefixes = ["model.diffusion_model.", "model.model.", "net."] - - result1 = model_detection.normalize_layer_name( - "model.diffusion_model.block.attn.weight", - known_prefixes - ) - self.assertEqual(result1, "block.attn") - - result2 = model_detection.normalize_layer_name( - "model.model.encoder.layer.weight", - known_prefixes - ) - self.assertEqual(result2, "encoder.layer") - - result3 = model_detection.normalize_layer_name( - "net.transformer.blocks.0.weight", - known_prefixes - ) - self.assertEqual(result3, "transformer.blocks.0") - - def test_strip_scale_weight_suffix(self): - """Test stripping scale_weight suffix""" - known_prefixes = ["model.diffusion_model."] - result = model_detection.normalize_layer_name( - "model.diffusion_model.layer1.scale_weight", - known_prefixes - ) - self.assertEqual(result, "layer1") - - def test_strip_bias_suffix(self): - """Test stripping bias suffix""" - known_prefixes = ["model.diffusion_model."] - result = model_detection.normalize_layer_name( - "model.diffusion_model.layer1.bias", - known_prefixes - ) - self.assertEqual(result, "layer1") - - def test_no_prefix_match(self): - """Test with no prefix match""" - known_prefixes = ["model.diffusion_model."] - result = model_detection.normalize_layer_name( - "other.model.layer1.weight", - known_prefixes - ) - # Should strip suffix but not prefix - self.assertEqual(result, "other.model.layer1") - - -class TestDetectLayerQuantization(unittest.TestCase): - """Test the detect_layer_quantization function""" - - def test_no_quantization(self): - """Test with no quantization markers""" - state_dict = { - "model.diffusion_model.layer1.weight": torch.randn(10, 20), - "model.diffusion_model.layer2.weight": torch.randn(20, 30), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - self.assertIsNone(result) - - def test_legacy_scaled_fp8(self): - """Test that legacy scaled_fp8 marker returns None""" - # Create FP8 tensor by converting from float32 - fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - state_dict = { - "model.diffusion_model.scaled_fp8": torch.tensor([], dtype=torch.float8_e4m3fn), - "model.diffusion_model.layer1.weight": fp8_weight, - "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - # Should return None to trigger legacy path - self.assertIsNone(result) - - def test_metadata_format(self): - """Test with new metadata format""" - metadata = { - "format_version": "1.0", - "layers": { - "layer1": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": True} - }, - "layer2": { - "format": "fp8_e5m2_scaled", - "params": {"use_fp8_matmul": True} - } - } - } - state_dict = { - "model.diffusion_model._quantization_metadata": metadata, - "model.diffusion_model.layer1.weight": torch.randn(10, 20), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - - self.assertIsNotNone(result) - self.assertIn("layer1", result) - self.assertIn("layer2", result) - self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled") - self.assertEqual(result["layer2"]["format"], "fp8_e5m2_scaled") - # Metadata should be popped from state_dict - self.assertNotIn("model.diffusion_model._quantization_metadata", state_dict) - - def test_mixed_precision_detection(self): - """Test detection of mixed precision via scale patterns""" - # Create FP8 tensors by converting from float32 - fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - fp8_weight3 = torch.randn(30, 40, dtype=torch.float32).to(torch.float8_e4m3fn) - state_dict = { - # Layer 1: FP8 (has scale_weight) - "model.diffusion_model.layer1.weight": fp8_weight1, - "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), - # Layer 2: Standard (no scale_weight) - "model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), - # Layer 3: FP8 (has scale_weight) - "model.diffusion_model.layer3.weight": fp8_weight3, - "model.diffusion_model.layer3.scale_weight": torch.tensor(1.0), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - - self.assertIsNotNone(result) - self.assertIn("layer1", result) - self.assertIn("layer3", result) - self.assertNotIn("layer2", result) # Layer 2 not quantized - self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled") - self.assertEqual(result["layer3"]["format"], "fp8_e4m3fn_scaled") - - def test_all_layers_quantized(self): - """Test that uniform quantization (all layers) returns None""" - # Create FP8 tensors by converting from float32 - fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - fp8_weight2 = torch.randn(20, 30, dtype=torch.float32).to(torch.float8_e4m3fn) - state_dict = { - # All layers have scale_weight - "model.diffusion_model.layer1.weight": fp8_weight1, - "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), - "model.diffusion_model.layer2.weight": fp8_weight2, - "model.diffusion_model.layer2.scale_weight": torch.tensor(1.0), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - - # If all layers are quantized, it's not mixed precision - # Should return None to use legacy scaled_fp8_ops path - self.assertIsNone(result) - - def test_fp8_e5m2_detection(self): - """Test detection of FP8 E5M2 format""" - # Create FP8 E5M2 tensor by converting from float32 - fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e5m2) - state_dict = { - "model.diffusion_model.layer1.weight": fp8_weight, - "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), - "model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - - self.assertIsNotNone(result) - self.assertIn("layer1", result) - self.assertEqual(result["layer1"]["format"], "fp8_e5m2_scaled") - - def test_invalid_metadata(self): - """Test with invalid metadata format""" - state_dict = { - "model.diffusion_model._quantization_metadata": "invalid_string", - "model.diffusion_model.layer1.weight": torch.randn(10, 20), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - # Should return None on invalid metadata - self.assertIsNone(result) - - def test_different_prefix(self): - """Test with different model prefix (audio model)""" - # Create FP8 tensor by converting from float32 - fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - state_dict = { - "model.model.layer1.weight": fp8_weight, - "model.model.layer1.scale_weight": torch.tensor(1.0), - "model.model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), - } - result = model_detection.detect_layer_quantization(state_dict, "model.model.") - - self.assertIsNotNone(result) - self.assertIn("layer1", result) - - -class TestPickOperationsIntegration(unittest.TestCase): - """Test pick_operations with model_config parameter""" - - def test_backward_compatibility(self): - """Test that pick_operations works without model_config (legacy)""" - from comfy import ops - - # Should work without model_config parameter - result = ops.pick_operations(None, None) - self.assertIsNotNone(result) - self.assertEqual(result, ops.disable_weight_init) - - def test_with_model_config_no_quant(self): - """Test with model_config but no quantization""" - from comfy import ops, supported_models_base - - model_config = supported_models_base.BASE({}) - model_config.layer_quant_config = None - - result = ops.pick_operations(None, None, model_config=model_config) - self.assertIsNotNone(result) - # Should use standard path - self.assertEqual(result, ops.disable_weight_init) - - def test_legacy_scaled_fp8(self): - """Test that legacy scaled_fp8 still works""" - from comfy import ops, supported_models_base - - model_config = supported_models_base.BASE({}) - model_config.scaled_fp8 = torch.float8_e4m3fn - - result = ops.pick_operations( - None, None, - scaled_fp8=torch.float8_e4m3fn, - model_config=model_config - ) - self.assertIsNotNone(result) - # Should return scaled_fp8_ops (the returned class is the inner class) - # Check that it's not the standard disable_weight_init - self.assertNotEqual(result, ops.disable_weight_init) - # Verify it has Linear class - self.assertTrue(hasattr(result, 'Linear')) - - -if __name__ == "__main__": - unittest.main() - diff --git a/tests-unit/comfy_test/test_quant_registry.py b/tests-unit/comfy_test/test_quant_registry.py deleted file mode 100644 index 5c624b1db9d8..000000000000 --- a/tests-unit/comfy_test/test_quant_registry.py +++ /dev/null @@ -1,399 +0,0 @@ -""" -Unit tests for tensor subclass quantization system. -Tests the new QuantizedTensorFP8 subclass and operation handlers. -""" - -import unittest -import torch -import sys -import os - -# Add comfy to path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) - -from comfy import ops -from comfy import quant_ops - - -class TestQuantizedTensorFP8(unittest.TestCase): - """Test the QuantizedTensorFP8 tensor subclass""" - - def test_creation(self): - """Test creating a QuantizedTensorFP8""" - fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(2.0) - - qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16) - - self.assertIsInstance(qt, quant_ops.QuantizedTensorFP8) - self.assertEqual(qt.shape, (256, 128)) - self.assertEqual(qt.dtype, torch.float8_e4m3fn) - self.assertEqual(qt._scale, scale) - self.assertEqual(qt._orig_dtype, torch.bfloat16) - - def test_dequantize(self): - """Test explicit dequantization""" - # Create a simple FP8 tensor - fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(3.0) - - qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.float32) - dequantized = qt.dequantize() - - # Dequantized should be approximately ones * 3.0 - self.assertEqual(dequantized.dtype, torch.float32) - self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) - - def test_repr(self): - """Test string representation""" - fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(2.5) - - qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16) - repr_str = repr(qt) - - self.assertIn("QuantizedTensorFP8", repr_str) - self.assertIn("shape", repr_str) - self.assertIn("scale", repr_str) - - -class TestOperationRegistry(unittest.TestCase): - """Test the operation registry system""" - - def test_registry_basics(self): - """Test that operations are registered""" - registered_ops = quant_ops.list_registered_ops() - - # Check that key operations are registered - self.assertIn(torch.ops.aten.linear.default, registered_ops) - self.assertIn(torch.ops.aten.silu.default, registered_ops) - self.assertIn(torch.ops.aten.layer_norm.default, registered_ops) - self.assertIn(torch.ops.aten.add.Tensor, registered_ops) - self.assertIn(torch.ops.aten.mul.Tensor, registered_ops) - - def test_get_handler(self): - """Test getting a registered handler""" - handler = quant_ops.get_quant_handler(torch.ops.aten.linear.default) - self.assertIsNotNone(handler) - self.assertTrue(callable(handler)) - - def test_custom_registration(self): - """Test registering a custom operation""" - - # Define a custom handler - @quant_ops.register_quant_op(torch.ops.aten.relu.default) - def custom_relu_handler(func, args, kwargs): - return func(*args, **kwargs) - - # Verify registration - handler = quant_ops.get_quant_handler(torch.ops.aten.relu.default) - self.assertIsNotNone(handler) - self.assertEqual(handler, custom_relu_handler) - - -class TestLinearHandler(unittest.TestCase): - """Test the linear operation handler""" - - def test_linear_with_quantized_weight(self): - """Test F.linear with quantized weight""" - # Set seed for reproducibility - torch.manual_seed(42) - - # Create quantized weight - weight_fp32 = torch.randn(256, 128, dtype=torch.float32) - scale = torch.tensor(2.0) - weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) - weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32) - - # Create input - input_tensor = torch.randn(16, 128, dtype=torch.float32) - - # Call linear (should trigger dispatch) - output = torch.nn.functional.linear(input_tensor, weight_q, bias=None) - - # Verify output shape - self.assertEqual(output.shape, (16, 256)) - - # Verify it's approximately correct (allowing for FP8 quantization error) - # Note: FP8 has limited precision, so use very loose tolerance - expected = torch.nn.functional.linear(input_tensor, weight_fp32, bias=None) - # Just check that it's in the right ballpark (within 50% error on average) - mean_rel_error = ((output - expected).abs() / (expected.abs() + 1e-6)).mean() - self.assertLess(mean_rel_error, 0.5, f"Mean relative error {mean_rel_error:.3f} is too large") - - def test_linear_with_bias(self): - """Test F.linear with quantized weight and bias""" - weight_fp32 = torch.randn(64, 32, dtype=torch.float32) - scale = torch.tensor(1.5) - weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) - weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32) - - input_tensor = torch.randn(8, 32, dtype=torch.float32) - bias = torch.randn(64, dtype=torch.float32) - - output = torch.nn.functional.linear(input_tensor, weight_q, bias) - - self.assertEqual(output.shape, (8, 64)) - - -class TestActivationHandlers(unittest.TestCase): - """Test activation function handlers""" - - def test_silu_with_quantized_input(self): - """Test SiLU with quantized input""" - # Create quantized input - input_fp32 = torch.randn(16, 128, dtype=torch.float32) - scale = torch.tensor(1.0) - input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn) - input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32) - - # Apply SiLU - output = torch.nn.functional.silu(input_q) - - # Should return a QuantizedTensorFP8 - self.assertIsInstance(output, quant_ops.QuantizedTensorFP8) - - # Verify approximate correctness - expected = torch.nn.functional.silu(input_fp32) - output_dq = output.dequantize() - self.assertTrue(torch.allclose(output_dq, expected, rtol=0.2, atol=0.2)) - - def test_layernorm_dequantizes(self): - """Test that LayerNorm dequantizes input""" - # Create quantized input - input_fp32 = torch.randn(16, 128, dtype=torch.float32) - scale = torch.tensor(1.0) - input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn) - input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32) - - # Apply LayerNorm - weight = torch.ones(128) - bias = torch.zeros(128) - output = torch.nn.functional.layer_norm(input_q, (128,), weight, bias) - - # Should NOT be quantized (LayerNorm breaks quantization) - self.assertNotIsInstance(output, quant_ops.QuantizedTensorFP8) - self.assertEqual(output.dtype, torch.float32) - - -class TestElementwiseHandlers(unittest.TestCase): - """Test element-wise operation handlers""" - - def test_add_mixed_tensors(self): - """Test addition with mixed quantized/non-quantized tensors""" - # Create quantized tensor - a_fp32 = torch.ones(10, 20, dtype=torch.float32) - scale = torch.tensor(1.0) - a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn) - a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32) - - # Non-quantized tensor - b = torch.ones(10, 20, dtype=torch.float32) * 2.0 - - # Add them - result = a_q + b - - # Should be dequantized - self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) - self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 3.0, rtol=0.1)) - - def test_mul_quantized_tensors(self): - """Test multiplication of two quantized tensors""" - a_fp32 = torch.ones(10, 20) * 2.0 - scale_a = torch.tensor(1.0) - a_fp8 = (a_fp32 / scale_a).to(torch.float8_e4m3fn) - a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale_a, orig_dtype=torch.float32) - - b_fp32 = torch.ones(10, 20) * 3.0 - scale_b = torch.tensor(1.0) - b_fp8 = (b_fp32 / scale_b).to(torch.float8_e4m3fn) - b_q = quant_ops.QuantizedTensorFP8(b_fp8, scale_b, orig_dtype=torch.float32) - - result = a_q * b_q - - # Should be dequantized - self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) - self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 6.0, rtol=0.1)) - - -class TestFallbackMechanism(unittest.TestCase): - """Test fallback for unsupported operations""" - - def test_unsupported_op_dequantizes(self): - """Test that unsupported operations fall back to dequantization""" - # Set seed for reproducibility - torch.manual_seed(42) - - # Create quantized tensor - a_fp32 = torch.randn(10, 20, dtype=torch.float32) - scale = torch.tensor(1.0) - a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn) - a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32) - - # Call an operation that doesn't have a registered handler - # For example, torch.abs - result = torch.abs(a_q) - - # Should work via fallback (dequantize → abs → return) - self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) - expected = torch.abs(a_fp32) - # FP8 introduces quantization error, so use loose tolerance - mean_error = (result - expected).abs().mean() - self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") - - -class TestMixedPrecisionOps(unittest.TestCase): - """Test MixedPrecisionOps integration""" - - def test_linear_layer_creation(self): - """Test that MixedPrecisionOps.Linear can be created""" - layer = ops.MixedPrecisionOps.Linear(128, 256, bias=True, device="cpu", dtype=torch.float32) - - self.assertIsInstance(layer, ops.MixedPrecisionOps.Linear) - self.assertFalse(layer._quantization_initialized) - self.assertIsNone(layer.quant_format) - - def test_layer_quant_config_detection(self): - """Test that layer quantization config is detected during load""" - # Set up layer config - ops.MixedPrecisionOps._layer_quant_config = { - "test_layer": { - "format": "fp8_e4m3fn", - "params": {} - } - } - - # Create a state dict with quantized weight - weight_fp32 = torch.randn(256, 128, dtype=torch.float32) - scale = torch.tensor(2.0) - weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) - - state_dict = { - "model.diffusion_model.test_layer.weight": weight_fp8, - "model.diffusion_model.test_layer.scale_weight": scale, - } - - # Create layer and load - layer = ops.MixedPrecisionOps.Linear(128, 256, bias=False, device="cpu", dtype=torch.float8_e4m3fn) - layer.weight = torch.nn.Parameter(torch.zeros(256, 128, dtype=torch.float8_e4m3fn)) - - # Manually call _load_from_state_dict - layer._load_from_state_dict( - state_dict, - prefix="model.diffusion_model.test_layer.", - local_metadata={}, - strict=True, - missing_keys=[], - unexpected_keys=[], - error_msgs=[] - ) - - # Verify quantization was initialized - self.assertTrue(layer._quantization_initialized) - self.assertEqual(layer.quant_format, "fp8_e4m3fn") - self.assertIsNotNone(layer.quant_scale) - - # Verify weight is wrapped - self.assertIsInstance(layer.weight.data, quant_ops.QuantizedTensorFP8) - - # Clean up - ops.MixedPrecisionOps._layer_quant_config = {} - - -class TestBackwardCompatibility(unittest.TestCase): - """Test backward compatibility with legacy systems""" - - def test_legacy_ops_classes_exist(self): - """Test that legacy ops classes still exist""" - self.assertTrue(hasattr(ops, 'disable_weight_init')) - self.assertTrue(hasattr(ops, 'manual_cast')) - self.assertTrue(hasattr(ops, 'fp8_ops')) - self.assertTrue(hasattr(ops, 'scaled_fp8_ops')) - - def test_pick_operations_legacy_path(self): - """Test pick_operations returns correct class for legacy cases""" - # Test standard case - result = ops.pick_operations(torch.float32, torch.float32) - self.assertEqual(result, ops.disable_weight_init) - - # Test manual cast case - result = ops.pick_operations(torch.float32, torch.float16) - self.assertEqual(result, ops.manual_cast) - - -class TestFP8LinearUnification(unittest.TestCase): - """Test that fp8_linear now uses the unified tensor subclass infrastructure""" - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA required for FP8") - def test_fp8_linear_uses_tensor_subclass(self): - """Verify fp8_linear wraps tensors in QuantizedTensorFP8""" - torch.manual_seed(42) - - # Create a mock Linear layer with FP8 weight - linear = ops.fp8_ops.Linear(4, 3, bias=True) - linear.weight = torch.nn.Parameter( - torch.randn(3, 4, dtype=torch.bfloat16).to(torch.float8_e4m3fn), - requires_grad=False - ) - linear.bias = torch.nn.Parameter( - torch.randn(3, dtype=torch.bfloat16), - requires_grad=False - ) - linear.scale_weight = torch.tensor(1.0) - linear.scale_input = None # No input scaling - - # Create input - input_tensor = torch.randn(2, 4, dtype=torch.bfloat16) - - # Call fp8_linear - should work without errors - try: - result = ops.fp8_linear(linear, input_tensor) - self.assertIsNotNone(result) - self.assertEqual(result.shape, (2, 3)) - except Exception as e: - # On CPU or unsupported hardware, _scaled_mm might not be available - # but the function should still complete without syntax errors - pass - - def test_fp8_linear_maintains_signature(self): - """Verify fp8_linear maintains its original function signature""" - import inspect - sig = inspect.signature(ops.fp8_linear) - params = list(sig.parameters.keys()) - - # Should have 'self' and 'input' parameters - self.assertIn('self', params) - self.assertIn('input', params) - self.assertEqual(len(params), 2) - - def test_fp8_linear_returns_none_for_non_fp8(self): - """Verify fp8_linear returns None for non-FP8 weights""" - # Create a Linear layer with BF16 weight (not FP8) - linear = ops.disable_weight_init.Linear(4, 3, bias=False) - linear.weight = torch.nn.Parameter( - torch.randn(3, 4, dtype=torch.bfloat16), - requires_grad=False - ) - - input_tensor = torch.randn(2, 4, dtype=torch.bfloat16) - - # Should return None for non-FP8 weights - result = ops.fp8_linear(linear, input_tensor) - self.assertIsNone(result) - - def test_fp8_ops_linear_uses_fp8_linear(self): - """Verify fp8_ops.Linear still uses fp8_linear in forward pass""" - linear = ops.fp8_ops.Linear(4, 3, bias=False) - - # Verify the class has the forward_comfy_cast_weights method - self.assertTrue(hasattr(linear, 'forward_comfy_cast_weights')) - - # The forward_comfy_cast_weights should attempt to call fp8_linear - # (we can't easily test this without mocking, but we verify structure) - import inspect - source = inspect.getsource(linear.forward_comfy_cast_weights) - self.assertIn('fp8_linear', source) - - -if __name__ == "__main__": - unittest.main() From e8d267b6605bb553c7c52f61b8777d44895fd24a Mon Sep 17 00:00:00 2001 From: lspindler Date: Mon, 27 Oct 2025 08:52:50 +0100 Subject: [PATCH 45/49] ruff lint --- comfy/model_base.py | 2 +- comfy/model_detection.py | 2 +- comfy/ops.py | 20 ++--- comfy/quant_ops.py | 76 +++++++++---------- .../comfy_quant/test_mixed_precision.py | 68 ++++++++--------- tests-unit/comfy_quant/test_quant_registry.py | 67 ++++++++-------- 6 files changed, 117 insertions(+), 118 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index e0589ba92095..7c788d085482 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -332,7 +332,7 @@ def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_ if self.model_config.scaled_fp8 is not None: unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8) - + # Save mixed precision metadata if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config: metadata = { diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 378250e04494..3142a7fc388c 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -16,7 +16,7 @@ def detect_layer_quantization(metadata): logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})") return quant_metadata["layers"] else: - raise ValueError(f"Invalid quantization metadata format") + raise ValueError("Invalid quantization metadata format") return None diff --git a/comfy/ops.py b/comfy/ops.py index 911228b51dda..a5e9a4d92e50 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -345,7 +345,7 @@ class Embedding(disable_weight_init.Embedding): def fp8_linear(self, input): """ - Legacy FP8 linear function for backward compatibility. + Legacy FP8 linear function for backward compatibility. Uses QuantizedTensor subclass for dispatch. """ dtype = self.weight.dtype @@ -359,7 +359,7 @@ def fp8_linear(self, input): input_shape = input.shape input_dtype = input.dtype - + if len(input.shape) == 3: w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) @@ -374,14 +374,14 @@ def fp8_linear(self, input): scale_input = torch.ones((), device=input.device, dtype=torch.float32) else: scale_input = scale_input.to(input.device) - + # Wrap weight in QuantizedTensor - this enables unified dispatch # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, fp8_dtype=dtype) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) - + if tensor_2d: return o.reshape(input_shape[0], -1) return o.reshape((-1, input_shape[1], self.weight.shape[0])) @@ -523,8 +523,8 @@ def __init__( def reset_parameters(self): return None - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): device = self.factory_kwargs["device"] @@ -540,10 +540,10 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) if quant_format is None: raise ValueError(f"Unknown quantization format for layer {layer_name}") - + mixin = QUANT_FORMAT_MIXINS[quant_format] self.layout_type = mixin["layout_type"] - + layout_params = { 'scale': state_dict.pop(f"{prefix}weight_scale", None), 'orig_dtype': MixedPrecisionOps._compute_dtype @@ -578,7 +578,7 @@ def forward(self, input, *args, **kwargs): not isinstance(input, QuantizedTensor)): input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype) return self._forward(input, self.weight, self.bias) - + def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: @@ -586,7 +586,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ MixedPrecisionOps._compute_dtype = compute_dtype logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") return MixedPrecisionOps - + fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 96d2fa03fdbd..aa1a231bd15c 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -31,7 +31,7 @@ def register_generic_util(torch_op): Decorator to register a generic utility that works for all layouts. Args: torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default) - + Example: @register_generic_util(torch.ops.aten.detach.default) def generic_detach(func, args, kwargs): @@ -78,10 +78,10 @@ def _copy_layout_params(params): class QuantizedLayout: """ Base class for quantization layouts. - + A layout encapsulates the format-specific logic for quantization/dequantization and provides a uniform interface for extracting raw tensors needed for computation. - + New quantization formats should subclass this and implement the required methods. """ @classmethod @@ -90,8 +90,8 @@ def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]: @staticmethod def dequantize(qdata, **layout_params) -> torch.Tensor: - raise NotImplementedError(f"TensorLayout must implement dequantize()") - + raise NotImplementedError("TensorLayout must implement dequantize()") + @classmethod def get_plain_tensors(cls, qtensor) -> torch.Tensor: raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()") @@ -100,45 +100,45 @@ def get_plain_tensors(cls, qtensor) -> torch.Tensor: class QuantizedTensor(torch.Tensor): """ Universal quantized tensor that works with any layout. - + This tensor subclass uses a pluggable layout system to support multiple quantization formats (FP8, INT4, INT8, etc.) without code duplication. - + The layout_type determines format-specific behavior, while common operations (detach, clone, to) are handled generically. - + Attributes: _qdata: The quantized tensor data _layout_type: Layout class (e.g., TensorCoreFP8Layout) _layout_params: Dict with layout-specific params (scale, zero_point, etc.) """ - + @staticmethod def __new__(cls, qdata, layout_type, layout_params): """ Create a quantized tensor. - + Args: qdata: The quantized data tensor layout_type: Layout class (subclass of QuantizedLayout) layout_params: Dict with layout-specific parameters """ return torch.Tensor._make_subclass(cls, qdata, require_grad=False) - + def __init__(self, qdata, layout_type, layout_params): self._qdata = qdata.contiguous() self._layout_type = layout_type self._layout_params = layout_params - + def __repr__(self): layout_name = self._layout_type.__name__ param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" - + @property def layout_type(self): return self._layout_type - + def __tensor_flatten__(self): """ Tensor flattening protocol for proper device movement. @@ -147,7 +147,7 @@ def __tensor_flatten__(self): ctx = { "layout_type": self._layout_type, } - + tensor_params = {} non_tensor_params = {} for k, v in self._layout_params.items(): @@ -155,17 +155,17 @@ def __tensor_flatten__(self): tensor_params[k] = v else: non_tensor_params[k] = v - + ctx["tensor_param_keys"] = list(tensor_params.keys()) ctx["non_tensor_params"] = non_tensor_params - + for k, v in tensor_params.items(): attr_name = f"_layout_param_{k}" object.__setattr__(self, attr_name, v) inner_tensors.append(attr_name) - + return inner_tensors, ctx - + @staticmethod def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): """ @@ -174,41 +174,41 @@ def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): """ layout_type = ctx["layout_type"] layout_params = dict(ctx["non_tensor_params"]) - + for key in ctx["tensor_param_keys"]: attr_name = f"_layout_param_{key}" layout_params[key] = inner_tensors[attr_name] - + return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params) - + @classmethod def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs) return cls(qdata, layout_type, layout_params) - + def dequantize(self) -> torch.Tensor: return self._layout_type.dequantize(self._qdata, **self._layout_params) - + @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): kwargs = kwargs or {} - + # Step 1: Check generic utilities first (detach, clone, to, etc.) if func in _GENERIC_UTILS: return _GENERIC_UTILS[func](func, args, kwargs) - + # Step 2: Check layout-specific handlers (linear, matmul, etc.) layout_type = _get_layout_from_args(args) if layout_type and func in _LAYOUT_REGISTRY: handler = _LAYOUT_REGISTRY[func].get(layout_type) if handler: return handler(func, args, kwargs) - + # Step 3: Fallback to dequantization if isinstance(args[0] if args else None, QuantizedTensor): logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") return cls._dequant_and_fallback(func, args, kwargs) - + @classmethod def _dequant_and_fallback(cls, func, args, kwargs): def dequant_arg(arg): @@ -217,7 +217,7 @@ def dequant_arg(arg): elif isinstance(arg, (list, tuple)): return type(arg)(dequant_arg(a) for a in arg) return arg - + new_args = dequant_arg(args) new_kwargs = dequant_arg(kwargs) return func(*new_args, **new_kwargs) @@ -239,13 +239,13 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout= f"QuantizedTensor: dtype conversion requested to {target_dtype}, " f"but not supported for quantized tensors. Ignoring dtype." ) - + if target_layout is not None and target_layout != torch.strided: logging.warning( f"QuantizedTensor: layout change requested to {target_layout}, " f"but not supported. Ignoring layout." ) - + # Handle device transfer current_device = qt._qdata.device if target_device is not None: @@ -254,7 +254,7 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout= target_device = torch.device(target_device) if isinstance(current_device, str): current_device = torch.device(current_device) - + if target_device != current_device: logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") new_q_data = qt._qdata.to(device=target_device) @@ -262,7 +262,7 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout= new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") return new_qt - + logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original") return qt @@ -318,7 +318,7 @@ def generic_to_dtype_layout(func, args, kwargs): def generic_copy_(func, args, kwargs): qt_dest = args[0] src = args[1] - + if isinstance(qt_dest, QuantizedTensor): if isinstance(src, QuantizedTensor): # Copy from another quantized tensor @@ -383,15 +383,15 @@ def fp8_linear(func, args, kwargs): input_tensor = args[0] weight = args[1] bias = args[2] if len(args) > 2 else None - + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) - + out_dtype = kwargs.get("out_dtype") if out_dtype is None: out_dtype = input_tensor._layout_params['orig_dtype'] - + weight_t = plain_weight.t() tensor_2d = False @@ -424,7 +424,7 @@ def fp8_linear(func, args, kwargs): return QuantizedTensor(output, TensorCoreFP8Layout, output_params) else: return output - + except Exception as e: raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index e3455276063d..1102f9bd4b28 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -16,7 +16,7 @@ def __init__(self, operations=ops.disable_weight_init): self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16) self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16) self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16) - + def forward(self, x): x = self.layer1(x) x = torch.nn.functional.relu(x) @@ -32,10 +32,10 @@ def test_all_layers_standard(self): """Test that model with no quantization works normally""" # Configure no quantization ops.MixedPrecisionOps._layer_quant_config = {} - + # Create model model = SimpleModel(operations=ops.MixedPrecisionOps) - + # Initialize weights manually model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16)) model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16)) @@ -43,19 +43,19 @@ def test_all_layers_standard(self): model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16)) model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16)) model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16)) - + # Initialize weight_function and bias_function for layer in [model.layer1, model.layer2, model.layer3]: layer.weight_function = [] layer.bias_function = [] - + # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) output = model(input_tensor) - + self.assertEqual(output.shape, (5, 40)) self.assertEqual(output.dtype, torch.bfloat16) - + def test_mixed_precision_load(self): """Test loading a mixed precision model from state dict""" # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard @@ -70,52 +70,52 @@ def test_mixed_precision_load(self): } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config - + # Create state dict with mixed precision fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn) - + state_dict = { # Layer 1: FP8 E4M3FN "layer1.weight": fp8_weight1, "layer1.bias": torch.randn(20, dtype=torch.bfloat16), "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), - + # Layer 2: Standard BF16 "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), "layer2.bias": torch.randn(30, dtype=torch.bfloat16), - + # Layer 3: FP8 E4M3FN "layer3.weight": fp8_weight3, "layer3.bias": torch.randn(40, dtype=torch.bfloat16), "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32), } - + # Create model and load state dict (strict=False because custom loading pops keys) model = SimpleModel(operations=ops.MixedPrecisionOps) model.load_state_dict(state_dict, strict=False) - + # Verify weights are wrapped in QuantizedTensor self.assertIsInstance(model.layer1.weight, QuantizedTensor) self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout) - + # Layer 2 should NOT be quantized self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) - + # Layer 3 should be quantized self.assertIsInstance(model.layer3.weight, QuantizedTensor) self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout) - + # Verify scales were loaded self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5) - + # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) output = model(input_tensor) - + self.assertEqual(output.shape, (5, 40)) - + def test_state_dict_quantized_preserved(self): """Test that quantized weights are preserved in state_dict()""" # Configure mixed precision @@ -126,7 +126,7 @@ def test_state_dict_quantized_preserved(self): } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config - + # Create and load model fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) state_dict1 = { @@ -138,22 +138,22 @@ def test_state_dict_quantized_preserved(self): "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - + model = SimpleModel(operations=ops.MixedPrecisionOps) model.load_state_dict(state_dict1, strict=False) - + # Save state dict state_dict2 = model.state_dict() - + # Verify layer1.weight is a QuantizedTensor with scale preserved self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout) - + # Verify non-quantized layers are standard tensors self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor) - + def test_weight_function_compatibility(self): """Test that weight_function (LoRA) works with quantized layers""" # Configure FP8 quantization @@ -164,7 +164,7 @@ def test_weight_function_compatibility(self): } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config - + # Create and load model fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) state_dict = { @@ -176,24 +176,24 @@ def test_weight_function_compatibility(self): "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - + model = SimpleModel(operations=ops.MixedPrecisionOps) model.load_state_dict(state_dict, strict=False) - + # Add a weight function (simulating LoRA) # This should trigger dequantization during forward pass def apply_lora(weight): lora_delta = torch.randn_like(weight) * 0.01 return weight + lora_delta - + model.layer1.weight_function.append(apply_lora) - + # Forward pass should work with LoRA (triggers weight_function path) input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) output = model(input_tensor) - + self.assertEqual(output.shape, (5, 40)) - + def test_error_handling_unknown_format(self): """Test that unknown formats raise error""" # Configure with unknown format @@ -204,7 +204,7 @@ def test_error_handling_unknown_format(self): } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config - + # Create state dict state_dict = { "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16), @@ -214,7 +214,7 @@ def test_error_handling_unknown_format(self): "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - + # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS model = SimpleModel(operations=ops.MixedPrecisionOps) with self.assertRaises(KeyError): diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py index 263581417177..26e91a7ee7d0 100644 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -11,51 +11,51 @@ class TestQuantizedTensor(unittest.TestCase): """Test the QuantizedTensor subclass with FP8 layout""" - + def test_creation(self): """Test creating a QuantizedTensor with TensorCoreFP8Layout""" fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(2.0) layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} - + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) - + self.assertIsInstance(qt, QuantizedTensor) self.assertEqual(qt.shape, (256, 128)) self.assertEqual(qt.dtype, torch.float8_e4m3fn) self.assertEqual(qt._layout_params['scale'], scale) self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) self.assertEqual(qt._layout_type, TensorCoreFP8Layout) - + def test_dequantize(self): """Test explicit dequantization""" fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(3.0) layout_params = {'scale': scale, 'orig_dtype': torch.float32} - + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) dequantized = qt.dequantize() - + self.assertEqual(dequantized.dtype, torch.float32) self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) - + def test_from_float(self): """Test creating QuantizedTensor from float tensor""" float_tensor = torch.randn(64, 32, dtype=torch.float32) scale = torch.tensor(1.5) - + qt = QuantizedTensor.from_float( - float_tensor, - TensorCoreFP8Layout, + float_tensor, + TensorCoreFP8Layout, scale=scale, fp8_dtype=torch.float8_e4m3fn ) - + self.assertIsInstance(qt, QuantizedTensor) self.assertEqual(qt.dtype, torch.float8_e4m3fn) self.assertEqual(qt.shape, (64, 32)) - + # Verify dequantization gives approximately original values dequantized = qt.dequantize() mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean() @@ -64,48 +64,48 @@ def test_from_float(self): class TestGenericUtilities(unittest.TestCase): """Test generic utility operations""" - + def test_detach(self): """Test detach operation on quantized tensor""" fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) - + # Detach should return a new QuantizedTensor qt_detached = qt.detach() - + self.assertIsInstance(qt_detached, QuantizedTensor) self.assertEqual(qt_detached.shape, qt.shape) self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout) - + def test_clone(self): """Test clone operation on quantized tensor""" fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) - + # Clone should return a new QuantizedTensor qt_cloned = qt.clone() - + self.assertIsInstance(qt_cloned, QuantizedTensor) self.assertEqual(qt_cloned.shape, qt.shape) self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout) - + # Verify it's a deep copy self.assertIsNot(qt_cloned._qdata, qt._qdata) - + def test_to_device(self): """Test device transfer""" fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) - + # Moving to same device should work (CPU to CPU) qt_cpu = qt.to('cpu') - + self.assertIsInstance(qt_cpu, QuantizedTensor) self.assertEqual(qt_cpu.device.type, 'cpu') self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') @@ -113,64 +113,63 @@ def test_to_device(self): class TestTensorCoreFP8Layout(unittest.TestCase): """Test the TensorCoreFP8Layout implementation""" - + def test_quantize(self): """Test quantization method""" float_tensor = torch.randn(32, 64, dtype=torch.float32) scale = torch.tensor(1.5) - + qdata, layout_params = TensorCoreFP8Layout.quantize( float_tensor, scale=scale, fp8_dtype=torch.float8_e4m3fn ) - + self.assertEqual(qdata.dtype, torch.float8_e4m3fn) self.assertEqual(qdata.shape, float_tensor.shape) self.assertIn('scale', layout_params) self.assertIn('orig_dtype', layout_params) self.assertEqual(layout_params['orig_dtype'], torch.float32) - + def test_dequantize(self): """Test dequantization method""" float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0 scale = torch.tensor(1.0) - + qdata, layout_params = TensorCoreFP8Layout.quantize( float_tensor, scale=scale, fp8_dtype=torch.float8_e4m3fn ) - + dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) - + # Should approximately match original self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) class TestFallbackMechanism(unittest.TestCase): """Test fallback for unsupported operations""" - + def test_unsupported_op_dequantizes(self): """Test that unsupported operations fall back to dequantization""" # Set seed for reproducibility torch.manual_seed(42) - + # Create quantized tensor a_fp32 = torch.randn(10, 20, dtype=torch.float32) scale = torch.tensor(1.0) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} a_q = QuantizedTensor.from_float( a_fp32, TensorCoreFP8Layout, scale=scale, fp8_dtype=torch.float8_e4m3fn ) - + # Call an operation that doesn't have a registered handler # For example, torch.abs result = torch.abs(a_q) - + # Should work via fallback (dequantize → abs → return) self.assertNotIsInstance(result, QuantizedTensor) expected = torch.abs(a_fp32) From f287d02419ddbd929fe58348bd99f881c4badef7 Mon Sep 17 00:00:00 2001 From: lspindler Date: Mon, 27 Oct 2025 10:04:57 +0100 Subject: [PATCH 46/49] Fix missing keys --- comfy/ops.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index a5e9a4d92e50..d7a8873e2a41 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -534,6 +534,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, if weight is None: raise ValueError(f"Missing weight for layer {layer_name}") + manually_loaded_keys = [weight_key] + if layer_name not in MixedPrecisionOps._layer_quant_config: self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) else: @@ -544,23 +546,33 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, mixin = QUANT_FORMAT_MIXINS[quant_format] self.layout_type = mixin["layout_type"] + scale_key = f"{prefix}weight_scale" layout_params = { - 'scale': state_dict.pop(f"{prefix}weight_scale", None), + 'scale': state_dict.pop(scale_key, None), 'orig_dtype': MixedPrecisionOps._compute_dtype } + if layout_params['scale'] is not None: + manually_loaded_keys.append(scale_key) + self.weight = torch.nn.Parameter( QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params), requires_grad=False ) for param_name, param_value in mixin["parameters"].items(): - _v = state_dict.pop(f"{prefix}{param_name}", None) + param_key = f"{prefix}{param_name}" + _v = state_dict.pop(param_key, None) if _v is None: continue - setattr(self, param_name, _v.to(device=device)) + setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + manually_loaded_keys.append(param_key) super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + for key in manually_loaded_keys: + if key in missing_keys: + missing_keys.remove(key) + def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) From 59a2e8c74e1df09c5590abd95e4908c9704a687c Mon Sep 17 00:00:00 2001 From: lspindler Date: Tue, 28 Oct 2025 07:33:19 +0100 Subject: [PATCH 47/49] Rename quant dtype parameter --- comfy/ops.py | 2 +- comfy/quant_ops.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index d7a8873e2a41..93731eedf276 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -379,7 +379,7 @@ def fp8_linear(self, input): # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) - quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, fp8_dtype=dtype) + quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) if tensor_2d: diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index aa1a231bd15c..b14e03084e8f 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -347,20 +347,20 @@ class TensorCoreFP8Layout(QuantizedLayout): - orig_dtype: Original dtype before quantization (for casting back) """ @classmethod - def quantize(cls, tensor, scale=None, fp8_dtype=torch.float8_e4m3fn): + def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn): orig_dtype = tensor.dtype if scale is None: - scale = torch.amax(tensor.abs()) / torch.finfo(fp8_dtype).max + scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max if not isinstance(scale, torch.Tensor): scale = torch.tensor(scale) scale = scale.to(device=tensor.device, dtype=torch.float32) - lp_amax = torch.finfo(fp8_dtype).max + lp_amax = torch.finfo(dtype).max tensor_scaled = tensor.float() / scale torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) - qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format) + qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format) layout_params = { 'scale': scale, From 135d3025ea0bf5b65ebd78558503f9237683c233 Mon Sep 17 00:00:00 2001 From: lspindler Date: Tue, 28 Oct 2025 07:33:42 +0100 Subject: [PATCH 48/49] Rename quant dtype parameter --- tests-unit/comfy_quant/test_quant_registry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py index 26e91a7ee7d0..2d7d3fa28e68 100644 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -122,7 +122,7 @@ def test_quantize(self): qdata, layout_params = TensorCoreFP8Layout.quantize( float_tensor, scale=scale, - fp8_dtype=torch.float8_e4m3fn + dtype=torch.float8_e4m3fn ) self.assertEqual(qdata.dtype, torch.float8_e4m3fn) @@ -139,7 +139,7 @@ def test_dequantize(self): qdata, layout_params = TensorCoreFP8Layout.quantize( float_tensor, scale=scale, - fp8_dtype=torch.float8_e4m3fn + dtype=torch.float8_e4m3fn ) dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) From 9d9f98cb728c729cf3c20f4fd138997c315558ce Mon Sep 17 00:00:00 2001 From: lspindler Date: Tue, 28 Oct 2025 08:02:26 +0100 Subject: [PATCH 49/49] Fix unittests for CPU build --- tests-unit/comfy_quant/test_mixed_precision.py | 7 +++++++ tests-unit/comfy_quant/test_quant_registry.py | 12 ++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 1102f9bd4b28..267bc177b786 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -6,6 +6,13 @@ # Add comfy to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) +def has_gpu(): + return torch.cuda.is_available() + +from comfy.cli_args import args +if not has_gpu(): + args.cpu = True + from comfy import ops from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py index 2d7d3fa28e68..47781102947c 100644 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -6,6 +6,13 @@ # Add comfy to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) +def has_gpu(): + return torch.cuda.is_available() + +from comfy.cli_args import args +if not has_gpu(): + args.cpu = True + from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout @@ -49,7 +56,7 @@ def test_from_float(self): float_tensor, TensorCoreFP8Layout, scale=scale, - fp8_dtype=torch.float8_e4m3fn + dtype=torch.float8_e4m3fn ) self.assertIsInstance(qt, QuantizedTensor) @@ -96,6 +103,7 @@ def test_clone(self): # Verify it's a deep copy self.assertIsNot(qt_cloned._qdata, qt._qdata) + @unittest.skipUnless(has_gpu(), "GPU not available") def test_to_device(self): """Test device transfer""" fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) @@ -163,7 +171,7 @@ def test_unsupported_op_dequantizes(self): a_fp32, TensorCoreFP8Layout, scale=scale, - fp8_dtype=torch.float8_e4m3fn + dtype=torch.float8_e4m3fn ) # Call an operation that doesn't have a registered handler