@@ -367,13 +367,13 @@ def to_neuron(self):
367367             ln_lm_head_params.append(self.logits_indices)
368368         if self.neuron_config.on_device_generation.dynamic:
369369             config = self.neuron_config.on_device_generation
370-               self.top_k = manipulator.duplicate(torch.tensor(config.top_k))
370+               self.top_k = manipulator.duplicate(torch.tensor(config.top_k, dtype=torch.int32))
371371             self.generation_inputs.append(self.top_k)
372-               self.top_p = manipulator.duplicate(torch.tensor(config.top_p))
372+               self.top_p = manipulator.duplicate(torch.tensor(config.top_p, dtype=torch.float32))
373373             self.generation_inputs.append(self.top_p)
374-               self.temperature = manipulator.duplicate(torch.tensor(config.temperature))
374+               self.temperature = manipulator.duplicate(torch.tensor(config.temperature, dtype=torch.float32))
375375             self.generation_inputs.append(self.temperature)
376-               self.top_p_min_tokens = manipulator.duplicate(torch.tensor(config.top_p_min_tokens))
376+               self.top_p_min_tokens = manipulator.duplicate(torch.tensor(config.top_p_min_tokens, dtype=torch.int32))
377377             self.generation_inputs.append(self.top_p_min_tokens)
378378         # FIXME: Use a better mechanism to pass extra params into the model
379379         ln_lm_head_params += self.generation_inputs
@@ -1107,11 +1107,11 @@ def validate_generation_configs(self, generation_config: GenerationConfig):
11071107     def update_generation_config(self, generation_config: config.GenerationConfig):
11081108         self.validate_generation_configs(generation_config)
11091109         num_cores = self.neuron_config.get_local_tp(self.tp_degree)
1110-            duplicate = lambda tensor: [torch.tensor(tensor) for _ in range(num_cores)]
1111-            ops.parallel_write(self.top_k, duplicate(generation_config.top_k))
1112-            ops.parallel_write(self.top_p, duplicate(generation_config.top_p))
1113-            ops.parallel_write(self.temperature, duplicate(generation_config.temperature))
1114-            ops.parallel_write(self.top_p_min_tokens, duplicate(generation_config.top_p_min_tokens))
1110+            duplicate = lambda tensor, dtype: [torch.tensor(tensor, dtype=dtype) for _ in range(num_cores)]
1111+            ops.parallel_write(self.top_k, duplicate(generation_config.top_k, dtype=torch.int32))
1112+            ops.parallel_write(self.top_p, duplicate(generation_config.top_p, dtype=torch.float32))
1113+            ops.parallel_write(self.temperature, duplicate(generation_config.temperature, dtype=torch.float32))
1114+            ops.parallel_write(self.top_p_min_tokens, duplicate(generation_config.top_p_min_tokens, dtype=torch.int32))
11151115
11161116
11171117def read_n_position (hlo_module , num_inputs ):
0 commit comments