Skip to content

Commit 3796944

Browse files
Add tests for issue 40739
#40739
1 parent 3341aa8 commit 3796944

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

tests/generation/test_utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3123,6 +3123,41 @@ def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self):
31233123
# update_candidate_strategy is called only once and therefore, assistant_model.generation_config.num_assistant_tokens should be either 4 or 7
31243124
self.assertTrue(assistant_model.generation_config.num_assistant_tokens in (4, 7))
31253125

3126+
def test_assisted_decoding_parameter_inheritance(self):
3127+
# This test ensures that assistant models inherit generation parameters from the main generate() call.
3128+
# Before the fix, assistant models would use their default values instead of user-specified values.
3129+
3130+
prompt = "Alice and Bob"
3131+
checkpoint = "EleutherAI/pythia-160m-deduped"
3132+
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
3133+
inputs = tokenizer(prompt, return_tensors="pt")
3134+
3135+
model = AutoModelForCausalLM.from_pretrained(checkpoint)
3136+
assistant_model = AutoModelForCausalLM.from_pretrained(checkpoint)
3137+
3138+
# Check assistant model defaults
3139+
self.assertEqual(assistant_model.generation_config.num_assistant_tokens, 20)
3140+
self.assertEqual(assistant_model.generation_config.assistant_confidence_threshold, 0.4)
3141+
self.assertEqual(assistant_model.generation_config.do_sample, False)
3142+
3143+
# Generate with user-specified values that differ from assistant defaults
3144+
generation_kwargs = {
3145+
"eos_token_id": -1,
3146+
"max_new_tokens": 5,
3147+
"assistant_model": assistant_model,
3148+
"do_sample": True,
3149+
"num_assistant_tokens": 7,
3150+
"assistant_confidence_threshold": 0.8,
3151+
}
3152+
3153+
model.generate(**inputs, **generation_kwargs)
3154+
3155+
# After generation, assistant model should have the user-specified values, not its defaults
3156+
# Inheritance applies to all main model parameters, not just ones that have "assistant" slots
3157+
self.assertEqual(assistant_model.generation_config.num_assistant_tokens, 7)
3158+
self.assertEqual(assistant_model.generation_config.assistant_confidence_threshold, 0.8)
3159+
self.assertEqual(assistant_model.generation_config.do_sample, True)
3160+
31263161
def test_assisted_decoding_num_assistant_tokens_heuristic_transient_schedule(self):
31273162
# This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly.
31283163

0 commit comments

Comments
 (0)