@@ -698,7 +698,6 @@ def _hlo_eagle_draft_unroll(self, hidden, tensors, layers_caches, layers_weights
698698 return logits , hidden , out_caches
699699
700700 def _hlo_fully_unrolled (self , n_positions , batch_size ):
701-
702701 self .builder .n_positions = n_positions
703702 if self .neuron_config .optimized_paged_attention and self .n_active_tokens == 1 :
704703 self .builder .num_active_blocks = batch_size
@@ -733,7 +732,7 @@ def fully_unrolled(scribe):
733732 else :
734733 logits , out_caches = self ._hlo_unroll (hidden , tensors , in_caches , layers_weights , pre_layer_params , lm_head_params )
735734 self ._hlo_cache_aliases (in_caches , out_caches )
736- output = self ._hlo_generation (logits , generation_params )
735+ output = self ._hlo_generation (logits , generation_params , start_ids = tensors [ 1 ] )
737736
738737 # Set the output
739738 out_caches = itertools .chain (* out_caches )
@@ -1006,11 +1005,12 @@ def ln_lm_head(scribe):
10061005 next_tok_id = scribe .s32 [batch_size ].Parameter (parameter_number = 1 )
10071006 else :
10081007 next_tok_id = scribe .s32 [1 ].Parameter (parameter_number = 1 )
1009- param_builder = DecoderParameterBuilder (scribe , 2 )
1008+ start_ids = scribe .s32 [batch_size ].Parameter (parameter_number = 2 )
1009+ param_builder = DecoderParameterBuilder (scribe , 3 )
10101010 ln_f_weight , ln_f_bias , head_weight , head_bias = self ._hlo_lm_head_params (param_builder )
1011- gneration_params = self ._hlo_generation_params (param_builder )
1011+ generation_params = self ._hlo_generation_params (param_builder )
10121012 logits = self .ln_lm_head_builder (hidden , next_tok_id , ln_f_weight , ln_f_bias , head_weight , head_bias , return_all_outputs = self .return_all_outputs )
1013- output = self ._hlo_generation (logits , gneration_params )
1013+ output = self ._hlo_generation (logits , generation_params , start_ids = start_ids )
10141014 if self .neuron_config .log_softmax_scores :
10151015 logits , scores = self ._hlo_post_layer (logits )
10161016 outputs = [logits , scores ]
@@ -1034,7 +1034,7 @@ def _hlo_generation_params(self, param_builder):
10341034 params .append (param )
10351035 return params
10361036
1037- def _hlo_generation (self , logits , params , early_return = False , return_probs = False ):
1037+ def _hlo_generation (self , logits , params , early_return = False , return_probs = False , start_ids = None ):
10381038 generation_config = self .neuron_config .on_device_generation
10391039 if generation_config is None :
10401040 return logits
@@ -1045,13 +1045,16 @@ def _hlo_generation(self, logits, params, early_return=False, return_probs=False
10451045 self .neuron_config .on_device_generation .top_p = top_p
10461046 self .neuron_config .on_device_generation .temperature = temperature
10471047 self .neuron_config .on_device_generation .top_p_min_tokens = top_p_min_tokens
1048+
1049+ seq_ids = start_ids if self .neuron_config .continuous_batching is not None else None
10481050 return generation .generate (
10491051 logits ,
10501052 logits_indices ,
10511053 config = generation_config ,
10521054 tp_degree = self .tp_degree ,
10531055 early_return = early_return ,
10541056 return_probs = return_probs ,
1057+ seq_ids = seq_ids ,
10551058 )
10561059
10571060 # Mainly used for serialization purposes.
@@ -2458,10 +2461,12 @@ def setup(self, layers, pre_layer_params, ln_lm_head_params):
24582461
24592462 hidden_buffers = list ()
24602463 last_token_id_buffers = list ()
2464+ start_ids_buffers = list ()
24612465 for input_buffer in self .input_buffers :
2462- hidden_buffer , _ , _ , last_token_id_buffer , * _ = input_buffer
2466+ hidden_buffer , _ , start_ids_buffer , last_token_id_buffer , * _ = input_buffer
24632467 hidden_buffers .append (hidden_buffer )
24642468 last_token_id_buffers .append (last_token_id_buffer )
2469+ start_ids_buffers .append (start_ids_buffer )
24652470
24662471 multi_layer_starts = range (0 , len (layers ), self .unroll )
24672472 multi_layers = [layers [start :start + self .unroll ] for start in multi_layer_starts ]
@@ -2480,7 +2485,7 @@ def setup(self, layers, pre_layer_params, ln_lm_head_params):
24802485 if self .neuron_config .is_valid_lm_head ():
24812486 for head_idx in range (0 ,len (self .ln_lm_head_kernels )):
24822487 output_tensors = [* self .logits_buffer [head_idx ]] if self .neuron_config .log_softmax_scores or self .neuron_config .is_eagle_target else [self .logits_buffer [head_idx ]]
2483- self .ln_lm_head_memories [head_idx ].setup ([hidden_buffers [head_idx ], last_token_id_buffers [head_idx ], * ln_lm_head_params ], output_tensors )
2488+ self .ln_lm_head_memories [head_idx ].setup ([hidden_buffers [head_idx ], last_token_id_buffers [head_idx ], start_ids_buffers [ head_idx ], * ln_lm_head_params ], output_tensors )
24842489 self .ln_lm_head_kernels [head_idx ].build ()
24852490 self .ln_lm_head_kernels [head_idx ].load ()
24862491
0 commit comments