Skip to content

Commit 30c2ed5

Browse files
Add tests for issue 40739
#40739
1 parent b54be08 commit 30c2ed5

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

tests/generation/test_utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3088,6 +3088,41 @@ def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self):
30883088
# update_candidate_strategy is called only once and therefore, assistant_model.generation_config.num_assistant_tokens should be either 4 or 7
30893089
self.assertTrue(assistant_model.generation_config.num_assistant_tokens in (4, 7))
30903090

3091+
def test_assisted_decoding_parameter_inheritance(self):
    # Regression test: generation parameters passed to the main model's generate() call must end up on the
    # assistant model's generation config. Before the fix, the assistant kept its own default values instead
    # of the user-specified ones.

    model_id = "EleutherAI/pythia-160m-deduped"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    encoded_prompt = tokenizer("Alice and Bob", return_tensors="pt")

    # The same checkpoint is loaded twice: once as the main model, once as the assistant
    model = AutoModelForCausalLM.from_pretrained(model_id)
    assistant_model = AutoModelForCausalLM.from_pretrained(model_id)

    def check_assistant_generation_config(expected_values):
        # `assistant_model.generation_config` is looked up on every comparison rather than cached, so the
        # check always inspects the config object the assistant holds at that moment.
        for attribute, expected in expected_values.items():
            self.assertEqual(getattr(assistant_model.generation_config, attribute), expected)

    # Sanity check: the values the freshly loaded assistant starts with
    check_assistant_generation_config(
        {
            "num_assistant_tokens": 20,
            "assistant_confidence_threshold": 0.4,
            "do_sample": False,
        }
    )

    # Every user-specified value below differs from the assistant defaults asserted above
    user_overrides = {
        "eos_token_id": -1,
        "max_new_tokens": 5,
        "assistant_model": assistant_model,
        "do_sample": True,
        "num_assistant_tokens": 7,
        "assistant_confidence_threshold": 0.8,
    }
    model.generate(**encoded_prompt, **user_overrides)

    # After generation the assistant must carry the user-specified values, not its defaults. `do_sample` is
    # included because inheritance covers all main model parameters, not only those with "assistant" slots.
    check_assistant_generation_config(
        {
            "num_assistant_tokens": 7,
            "assistant_confidence_threshold": 0.8,
            "do_sample": True,
        }
    )
3125+
30913126
def test_assisted_decoding_num_assistant_tokens_heuristic_transient_schedule(self):
30923127
# This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly.
30933128

0 commit comments

Comments
 (0)