diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py
index f2fc015e948f..894eac89a78f 100644
--- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py
+++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py
@@ -209,7 +209,7 @@ def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:

         # Get all loading fields in order
         loading_fields = cls.loading_fields()
-        result = {f: None for f in loading_fields}
+        result = dict.fromkeys(loading_fields)

         if load_id == "null":
             return result
diff --git a/src/diffusers/schedulers/scheduling_edm_euler.py b/src/diffusers/schedulers/scheduling_edm_euler.py
index dbeff3de5652..fd1c0bea33f3 100644
--- a/src/diffusers/schedulers/scheduling_edm_euler.py
+++ b/src/diffusers/schedulers/scheduling_edm_euler.py
@@ -406,7 +406,6 @@ def step(

         return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.Tensor,
@@ -415,23 +414,30 @@ def add_noise(
     ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
-        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
-            # mps does not support float64
-            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
-            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
-        else:
-            schedule_timesteps = self.timesteps.to(original_samples.device)
-            timesteps = timesteps.to(original_samples.device)
+        timesteps = timesteps.to(original_samples.device)

-        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
-        if self.begin_index is None:
-            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
-        elif self.step_index is not None:
-            # add_noise is called after first denoising step (for inpainting)
-            step_indices = [self.step_index] * timesteps.shape[0]
+        # Handle integer timesteps (training case)
+        if timesteps.dtype in (torch.int32, torch.int64):
+            # Training: reverse mapping since EDM sigmas are in descending order
+            # timestep 0 -> sigma_min, timestep 999 -> sigma_max
+            step_indices = self.config.num_train_timesteps - 1 - timesteps.long()
         else:
-            # add noise is called before first denoising step to create initial latent(img2img)
-            step_indices = [self.begin_index] * timesteps.shape[0]
+            if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+                # mps does not support float64
+                schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+                timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+            else:
+                schedule_timesteps = self.timesteps.to(original_samples.device)
+
+            # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
+            if self.begin_index is None:
+                step_indices = torch.tensor([self.index_for_timestep(t, schedule_timesteps) for t in timesteps])
+            elif self.step_index is not None:
+                # add_noise is called after first denoising step (for inpainting)
+                step_indices = torch.tensor([self.step_index] * timesteps.shape[0])
+            else:
+                # add noise is called before first denoising step to create initial latent(img2img)
+                step_indices = torch.tensor([self.begin_index] * timesteps.shape[0])

         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
diff --git a/tests/schedulers/test_scheduler_edm_euler.py b/tests/schedulers/test_scheduler_edm_euler.py
index acac4b1f4cae..699145e37b81 100644
--- a/tests/schedulers/test_scheduler_edm_euler.py
+++ b/tests/schedulers/test_scheduler_edm_euler.py
@@ -80,6 +80,67 @@ def test_full_loop_device(self, num_inference_steps=10, seed=0):
         assert abs(result_sum.item() - 34.1855) < 1e-3
         assert abs(result_mean.item() - 0.044) < 1e-3

+    def test_add_noise_with_integer_timesteps(self):
+        """Test that add_noise works with integer timesteps (training case) - Issue #7406"""
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        batch_size = 4
+        channels = 3
+        height = width = 32
+
+        # Create dummy data
+        original_samples = torch.randn(batch_size, channels, height, width)
+        noise = torch.randn_like(original_samples)
+
+        # Test with integer timesteps (training case)
+        timesteps = torch.randint(0, scheduler_config["num_train_timesteps"], (batch_size,), dtype=torch.int64)
+
+        # This should not raise an error
+        noisy_samples = scheduler.add_noise(original_samples, noise, timesteps)
+
+        # Verify output shape
+        self.assertEqual(noisy_samples.shape, original_samples.shape)
+
+        # Verify that noise was actually added
+        self.assertFalse(torch.allclose(noisy_samples, original_samples))
+
+        # Verify noise levels are correct (higher timestep = more noise)
+        t_low = torch.tensor([0], dtype=torch.int64)
+        t_high = torch.tensor([scheduler_config["num_train_timesteps"] - 1], dtype=torch.int64)
+
+        noisy_low = scheduler.add_noise(original_samples[:1], noise[:1], t_low)
+        noisy_high = scheduler.add_noise(original_samples[:1], noise[:1], t_high)
+
+        noise_low = (noisy_low - original_samples[:1]).abs().mean()
+        noise_high = (noisy_high - original_samples[:1]).abs().mean()
+
+        # Higher timestep should have more noise
+        self.assertGreater(noise_high, noise_low)
+
+    def test_add_noise_with_float_timesteps(self):
+        """Test that add_noise still works with float timesteps (inference case)"""
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+        scheduler.set_timesteps(50)
+
+        batch_size = 2
+
+        # Create dummy data with correct shape
+        original_samples = torch.randn(batch_size, 3, 32, 32)
+        noise = torch.randn_like(original_samples)
+
+        # Use float timesteps from the scheduler
+        timesteps = scheduler.timesteps[:batch_size]
+
+        # This should not raise an error
+        noisy_samples = scheduler.add_noise(original_samples, noise, timesteps)
+
+        # Verify output shape
+        self.assertEqual(noisy_samples.shape, original_samples.shape)
+
     # Override test_from_save_pretrained to use EDMEulerScheduler-specific logic
     def test_from_save_pretrained(self):
         kwargs = dict(self.forward_default_kwargs)
@@ -115,7 +176,6 @@ def test_from_save_pretrained(self):

             assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

-    # Override test_from_save_pretrained to use EDMEulerScheduler-specific logic
     def test_step_shape(self):
         num_inference_steps = 10

@@ -137,7 +197,6 @@ def test_step_shape(self):

         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)

-    # Override test_from_save_pretrained to use EDMEulerScheduler-specific logic
     def test_scheduler_outputs_equivalence(self):
         def set_nan_tensor_to_zero(t):
             t[t != t] = 0
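Not part of the patch: a minimal, hypothetical sketch of the training-time call this change is meant to unblock, assuming a default-configured `EDMEulerScheduler` with the patch above applied.

```python
import torch
from diffusers import EDMEulerScheduler

# Default config: num_train_timesteps=1000, so valid integer timesteps are 0..999
scheduler = EDMEulerScheduler()

clean_latents = torch.randn(4, 4, 64, 64)
noise = torch.randn_like(clean_latents)

# Integer timesteps as drawn in a typical training loop; with the patch,
# add_noise maps them onto the (descending) EDM sigma schedule directly
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (4,), dtype=torch.int64)

noisy_latents = scheduler.add_noise(clean_latents, noise, timesteps)
print(noisy_latents.shape)  # torch.Size([4, 4, 64, 64])
```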