
Commit 39462a4

Improve test patterns for QwenImage long prompt warning
- Move CaptureLogger import to top level following established patterns
- Use logging.WARNING constant instead of hardcoded value
- Simplify device handling to match other QwenImage tests
- Remove redundant variable assignments and comments
1 parent c7ac380 commit 39462a4
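
For context, the log-capture pattern this commit standardizes on looks roughly like the sketch below. The import locations, the get_logger name, and the setLevel call are taken from the diff itself; the warning text is a stand-in, and reading captured output from the `.out` attribute is an assumption based on how CaptureLogger is used elsewhere in the diffusers test suite.

from diffusers.utils import logging
from diffusers.utils.testing_utils import CaptureLogger

# Target the module-level logger whose warning the test asserts on.
logger = logging.get_logger("diffusers.models.transformers.transformer_qwenimage")
logger.setLevel(logging.WARNING)  # named constant instead of the magic number 30

with CaptureLogger(logger) as cap_logger:
    # Stand-in for the pipeline call that emits the long-prompt warning.
    logger.warning("Prompt is longer than the 512 tokens the model was trained on")

# Assumption: CaptureLogger collects emitted records as text on `cap_logger.out`.
assert "512" in cap_logger.out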

File tree

1 file changed: 4 additions, 7 deletions


tests/pipelines/qwenimage/test_qwenimage.py

Lines changed: 4 additions & 7 deletions
@@ -24,7 +24,7 @@
     QwenImagePipeline,
     QwenImageTransformer2DModel,
 )
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from diffusers.utils.testing_utils import CaptureLogger, enable_full_determinism, torch_device
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -264,27 +264,24 @@ def test_long_prompt_no_error(self):
 
     def test_long_prompt_warning(self):
         """Test that long prompts trigger appropriate warning about training limitation"""
-        from diffusers.utils.testing_utils import CaptureLogger
         from diffusers.utils import logging
 
-        device = torch_device
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
-        pipe.to(device)
+        pipe.to(torch_device)
 
         # Create prompt that will exceed 512 tokens to trigger warning
-        # Use a longer phrase and repeat more times to ensure we exceed the 512 token limit
         long_phrase = "A detailed photorealistic description of a complex scene with many elements "
         long_prompt = (long_phrase * 20)[:800]  # Create a prompt that will exceed 512 tokens
 
         # Capture transformer logging
         logger = logging.get_logger("diffusers.models.transformers.transformer_qwenimage")
-        logger.setLevel(30)  # WARNING level
+        logger.setLevel(logging.WARNING)
 
         with CaptureLogger(logger) as cap_logger:
             _ = pipe(
                 prompt=long_prompt,
-                generator=torch.Generator(device=device).manual_seed(0),
+                generator=torch.Generator(device=torch_device).manual_seed(0),
                 num_inference_steps=2,
                 guidance_scale=3.0,
                 true_cfg_scale=1.0,
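
To exercise just the touched test locally, a standard pytest invocation along these lines should work (assuming a development checkout of diffusers with the test dependencies installed):

pytest tests/pipelines/qwenimage/test_qwenimage.py -k test_long_prompt_warning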

0 commit comments
