Skip to content

Commit c7ac380

Browse files
committed
Add training limitation warning for QwenImage long prompts
- Add warning when prompts exceed 512 tokens (the model's training limit)
- Warn users about potential unpredictable behavior with long prompts
- Add comprehensive test with CaptureLogger to verify the warning system
- Follow established diffusers warning patterns for consistency
1 parent 5b516b0 commit c7ac380

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,14 @@ def _expand_pos_freqs_if_needed(self, required_len):
203203
# Calculate new size (use next power of 2 or round to nearest 512 for efficiency)
204204
new_max_len = max(required_len, int((required_len + 511) // 512) * 512)
205205

206+
# Log warning about potential quality degradation for long prompts
207+
if required_len > 512:
208+
logger.warning(
209+
f"QwenImage model was trained on prompts up to 512 tokens. "
210+
f"Current prompt requires {required_len} tokens, which may lead to unpredictable behavior. "
211+
f"Consider using shorter prompts for better results."
212+
)
213+
206214
# Generate expanded indices
207215
pos_index = torch.arange(new_max_len, device=self.pos_freqs.device)
208216
neg_index = torch.arange(new_max_len, device=self.neg_freqs.device).flip(0) * -1 - 1

tests/pipelines/qwenimage/test_qwenimage.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,4 +260,40 @@ def test_long_prompt_no_error(self):
260260
}
261261

262262
# This should not raise a RuntimeError about tensor dimension mismatch
263-
_ = pipe(**inputs)
263+
_ = pipe(**inputs)
264+
265+
def test_long_prompt_warning(self):
    """Verify that prompts exceeding 512 tokens emit the training-limit warning.

    Captures the transformer module's logger during a tiny pipeline run and
    asserts the warning text mentions both the 512-token training limit and
    the risk of unpredictable behavior.
    """
    # Imports kept local to the test, ordered conventionally.
    from diffusers.utils import logging
    from diffusers.utils.testing_utils import CaptureLogger

    device = torch_device
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.to(device)

    # Build a prompt intended to exceed the 512-token training limit.
    # NOTE(review): 800 characters of English typically tokenize to far fewer
    # than 512 subword tokens — confirm the dummy tokenizer used by
    # get_dummy_components() actually produces > 512 tokens here, otherwise
    # the warning will not fire and this test is flaky.
    long_phrase = "A detailed photorealistic description of a complex scene with many elements "
    long_prompt = (long_phrase * 20)[:800]

    # Capture warnings emitted by the transformer module's logger.
    logger = logging.get_logger("diffusers.models.transformers.transformer_qwenimage")
    logger.setLevel(logging.WARNING)  # named constant, not the magic number 30

    with CaptureLogger(logger) as cap_logger:
        _ = pipe(
            prompt=long_prompt,
            generator=torch.Generator(device=device).manual_seed(0),
            num_inference_steps=2,
            guidance_scale=3.0,
            true_cfg_scale=1.0,
            height=32,  # small size for a fast test
            width=32,  # small size for a fast test
            max_sequence_length=900,  # allow the long sequence through encoding
            output_type="pt",
        )

    # assertIn gives a useful failure message (prints cap_logger.out),
    # unlike assertTrue("..." in out) which only reports "False is not true".
    self.assertIn("512 tokens", cap_logger.out)
    self.assertIn("unpredictable behavior", cap_logger.out)

0 commit comments

Comments (0)