@@ -3123,6 +3123,41 @@ def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self):
3123
3123
# update_candidate_strategy is called only once and therefore, assistant_model.generation_config.num_assistant_tokens should be either 4 or 7
3124
3124
self .assertTrue (assistant_model .generation_config .num_assistant_tokens in (4 , 7 ))
3125
3125
3126
def test_assisted_decoding_parameter_inheritance(self):
    # Verifies that values passed to the main model's generate() call end up on the assistant
    # model's generation config, replacing whatever the assistant had before the call.
    # This holds for every overridden parameter, including ones without an "assistant" prefix
    # (e.g. `do_sample`).
    model_id = "EleutherAI/pythia-160m-deduped"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    encoded_prompt = tokenizer("Alice and Bob", return_tensors="pt")

    model = AutoModelForCausalLM.from_pretrained(model_id)
    assistant_model = AutoModelForCausalLM.from_pretrained(model_id)

    # Values the assistant's generation config is expected to hold before generate() runs.
    expected_defaults = {
        "num_assistant_tokens": 20,
        "assistant_confidence_threshold": 0.4,
        "do_sample": False,
    }
    for attr_name, default_value in expected_defaults.items():
        self.assertEqual(getattr(assistant_model.generation_config, attr_name), default_value)

    # User-specified values, each deliberately different from the default checked above.
    user_overrides = {
        "num_assistant_tokens": 7,
        "assistant_confidence_threshold": 0.8,
        "do_sample": True,
    }

    # The generated tokens are irrelevant here; only the effect on the assistant's config is checked.
    model.generate(
        **encoded_prompt,
        eos_token_id=-1,
        max_new_tokens=5,
        assistant_model=assistant_model,
        **user_overrides,
    )

    # After generation the assistant must carry the user-specified values, not its own defaults.
    for attr_name, user_value in user_overrides.items():
        self.assertEqual(getattr(assistant_model.generation_config, attr_name), user_value)

3126
3161
def test_assisted_decoding_num_assistant_tokens_heuristic_transient_schedule (self ):
3127
3162
# This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly.
3128
3163
0 commit comments