diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 36210b398906..07b340144653 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -507,21 +507,22 @@ def _cache_dependant_input_preparation_exporting( # else: # if input_ids.shape[1] != cache_position.shape[0]: # input_ids = input_ids[:, cache_position] + # We need to clone the outputs to avoid aliasing. def branch_1(inputs_embeds, cache_position): - return inputs_embeds[:, -cache_position.shape[0] :] + return inputs_embeds[:, -cache_position.shape[0] :].clone() def branch_2(input_ids, cache_position): - return input_ids[:, -cache_position.shape[0] :] + return input_ids[:, -cache_position.shape[0] :].clone() def branch_3(input_ids, cache_position): - return input_ids[:, cache_position] + return input_ids[:, cache_position].clone() inputs_embeds, input_ids = torch.cond( input_ids.shape[1] == 0, ( lambda input_ids, inputs_embeds, cache_position: ( branch_1(inputs_embeds, cache_position), - input_ids, + input_ids.clone(), ) ), ( @@ -534,7 +535,7 @@ def branch_3(input_ids, cache_position): torch.cond( input_ids.shape[1] != cache_position.shape[0], branch_3, - (lambda input_ids, cache_position: input_ids), + (lambda input_ids, cache_position: input_ids.clone()), [input_ids, cache_position], ) ),