@@ -3088,6 +3088,41 @@ def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self):
3088
3088
# update_candidate_strategy is called only once and therefore, assistant_model.generation_config.num_assistant_tokens should be either 4 or 7
3089
3089
self .assertTrue (assistant_model .generation_config .num_assistant_tokens in (4 , 7 ))
3090
3090
3091
+ def test_assisted_decoding_parameter_inheritance (self ):
3092
+ # This test ensures that assistant models inherit generation parameters from the main generate() call.
3093
+ # Before the fix, assistant models would use their default values instead of user-specified values.
3094
+
3095
+ prompt = "Alice and Bob"
3096
+ checkpoint = "EleutherAI/pythia-160m-deduped"
3097
+ tokenizer = AutoTokenizer .from_pretrained (checkpoint )
3098
+ inputs = tokenizer (prompt , return_tensors = "pt" )
3099
+
3100
+ model = AutoModelForCausalLM .from_pretrained (checkpoint )
3101
+ assistant_model = AutoModelForCausalLM .from_pretrained (checkpoint )
3102
+
3103
+ # Check assistant model defaults
3104
+ self .assertEqual (assistant_model .generation_config .num_assistant_tokens , 20 )
3105
+ self .assertEqual (assistant_model .generation_config .assistant_confidence_threshold , 0.4 )
3106
+ self .assertEqual (assistant_model .generation_config .do_sample , False )
3107
+
3108
+ # Generate with user-specified values that differ from assistant defaults
3109
+ generation_kwargs = {
3110
+ "eos_token_id" : - 1 ,
3111
+ "max_new_tokens" : 5 ,
3112
+ "assistant_model" : assistant_model ,
3113
+ "do_sample" : True ,
3114
+ "num_assistant_tokens" : 7 ,
3115
+ "assistant_confidence_threshold" : 0.8 ,
3116
+ }
3117
+
3118
+ model .generate (** inputs , ** generation_kwargs )
3119
+
3120
+ # After generation, assistant model should have the user-specified values, not its defaults
3121
+ # Inheritance applies to all main model parameters, not just ones that have "assistant" slots
3122
+ self .assertEqual (assistant_model .generation_config .num_assistant_tokens , 7 )
3123
+ self .assertEqual (assistant_model .generation_config .assistant_confidence_threshold , 0.8 )
3124
+ self .assertEqual (assistant_model .generation_config .do_sample , True )
3125
+
3091
3126
def test_assisted_decoding_num_assistant_tokens_heuristic_transient_schedule (self ):
3092
3127
# This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly.
3093
3128
0 commit comments