@@ -2369,6 +2369,20 @@ def generate(
2369
2369
self ._validate_model_kwargs (model_kwargs .copy ())
2370
2370
self ._validate_generation_mode (generation_mode , generation_config , generation_mode_kwargs )
2371
2371
2372
+ # Configure assistant model's generation_config with user parameters
2373
+ if assistant_model is not None :
2374
+ # The assistant model inherits the generation parameters passed as keyword arguments to the main generate() call, including:
2375
+ # - Assistant-specific parameters (num_assistant_tokens, assistant_confidence_threshold, etc.)
2376
+ # - General generation parameters (do_sample, max_new_tokens, temperature, etc.)
2377
+ # This ensures consistent behavior between main and assistant models. In the future,
2378
+ # assistant-specific overrides could be added (e.g., assistant_do_sample) to allow
2379
+ # different generation strategies for draft vs target models while maintaining the
2380
+ # inheritance-by-default behavior.
2381
+ assistant_generation_config , _ = assistant_model ._prepare_generation_config (
2382
+ assistant_model .generation_config , use_model_defaults , ** kwargs
2383
+ )
2384
+ assistant_model .generation_config = assistant_generation_config
2385
+
2372
2386
# Deprecation-related step: set Hub repo for deprecated strategies.
2373
2387
# NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode.
2374
2388
# It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps.
0 commit comments