Skip to content

Commit b54be08

Browse files
Fix num_assistant_tokens not configured in assistant model
#40739
1 parent 3378e7d commit b54be08

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

src/transformers/generation/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2369,6 +2369,20 @@ def generate(
23692369
self._validate_model_kwargs(model_kwargs.copy())
23702370
self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs)
23712371

2372+
# Configure assistant model's generation_config with user parameters
2373+
if assistant_model is not None:
2374+
# The assistant model inherits ALL generation parameters from the main generate() call, including:
2375+
# - Assistant-specific parameters (num_assistant_tokens, assistant_confidence_threshold, etc.)
2376+
# - General generation parameters (do_sample, max_new_tokens, temperature, etc.)
2377+
# This ensures consistent behavior between main and assistant models. In the future,
2378+
# assistant-specific overrides could be added (e.g., assistant_do_sample) to allow
2379+
# different generation strategies for draft vs target models while maintaining the
2380+
# inheritance-by-default behavior.
2381+
assistant_generation_config, _ = assistant_model._prepare_generation_config(
2382+
assistant_model.generation_config, use_model_defaults, **kwargs
2383+
)
2384+
assistant_model.generation_config = assistant_generation_config
2385+
23722386
# Deprecation-related step: set Hub repo for deprecated strategies.
23732387
# NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode.
23742388
# It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps.

0 commit comments

Comments
 (0)