Skip to content

Commit 66d4231

Browse files
aws-patlange and aws-yishanm
authored and committed
[generation] Enforce desired dtype for on_device_generation sampling params
GitOrigin-RevId: 13cd995f8a2aff8a66a39cc7a828e0d05c67d39e
1 parent fd1a28b commit 66d4231

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/transformers_neuronx/decoder.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -367,13 +367,13 @@ def to_neuron(self):
367367
ln_lm_head_params.append(self.logits_indices)
368368
if self.neuron_config.on_device_generation.dynamic:
369369
config = self.neuron_config.on_device_generation
370-
self.top_k = manipulator.duplicate(torch.tensor(config.top_k))
370+
self.top_k = manipulator.duplicate(torch.tensor(config.top_k, dtype=torch.int32))
371371
self.generation_inputs.append(self.top_k)
372-
self.top_p = manipulator.duplicate(torch.tensor(config.top_p))
372+
self.top_p = manipulator.duplicate(torch.tensor(config.top_p, dtype=torch.float32))
373373
self.generation_inputs.append(self.top_p)
374-
self.temperature = manipulator.duplicate(torch.tensor(config.temperature))
374+
self.temperature = manipulator.duplicate(torch.tensor(config.temperature, dtype=torch.float32))
375375
self.generation_inputs.append(self.temperature)
376-
self.top_p_min_tokens = manipulator.duplicate(torch.tensor(config.top_p_min_tokens))
376+
self.top_p_min_tokens = manipulator.duplicate(torch.tensor(config.top_p_min_tokens, dtype=torch.int32))
377377
self.generation_inputs.append(self.top_p_min_tokens)
378378
# FIXME: Use a better mechanism to pass extra params into the model
379379
ln_lm_head_params += self.generation_inputs
@@ -1107,11 +1107,11 @@ def validate_generation_configs(self, generation_config: GenerationConfig):
11071107
def update_generation_config(self, generation_config: config.GenerationConfig):
    """Push updated sampling parameters into the on-device generation buffers.

    Validates *generation_config*, then broadcasts each sampling parameter
    to every local tensor-parallel core via ``ops.parallel_write``, coercing
    each value to the dtype the compiled graph expects: int32 for the
    integer parameters (top_k, top_p_min_tokens) and float32 for the
    floating-point ones (top_p, temperature).

    NOTE(review): assumes ``self.top_k`` etc. were created with matching
    dtypes when on-device generation was set up — confirm against to_neuron.
    """
    self.validate_generation_configs(generation_config)
    cores = self.neuron_config.get_local_tp(self.tp_degree)

    def replicate(value, dtype):
        # One tensor per local core, forced to the dtype the graph expects.
        return [torch.tensor(value, dtype=dtype) for _ in range(cores)]

    ops.parallel_write(self.top_k, replicate(generation_config.top_k, torch.int32))
    ops.parallel_write(self.top_p, replicate(generation_config.top_p, torch.float32))
    ops.parallel_write(self.temperature, replicate(generation_config.temperature, torch.float32))
    ops.parallel_write(self.top_p_min_tokens, replicate(generation_config.top_p_min_tokens, torch.int32))
11151115

11161116

11171117
def read_n_position(hlo_module, num_inputs):

0 commit comments

Comments (0)