
Commit fd1a28b

aws-patlange authored and aws-yishanm committed
[generation] Fix per_batch_line sampling param lookup for CB
GitOrigin-RevId: 5039ec6b067d1431201563659bb05b9c90e79bb0
1 parent 51e8784 commit fd1a28b

2 files changed: +33, -13 lines

src/transformers_neuronx/decoder.py

Lines changed: 13 additions & 8 deletions
@@ -698,7 +698,6 @@ def _hlo_eagle_draft_unroll(self, hidden, tensors, layers_caches, layers_weights
         return logits, hidden, out_caches

     def _hlo_fully_unrolled(self, n_positions, batch_size):
-
         self.builder.n_positions = n_positions
         if self.neuron_config.optimized_paged_attention and self.n_active_tokens == 1:
             self.builder.num_active_blocks = batch_size
@@ -733,7 +732,7 @@ def fully_unrolled(scribe):
            else:
                logits, out_caches = self._hlo_unroll(hidden, tensors, in_caches, layers_weights, pre_layer_params, lm_head_params)
            self._hlo_cache_aliases(in_caches, out_caches)
-           output = self._hlo_generation(logits, generation_params)
+           output = self._hlo_generation(logits, generation_params, start_ids=tensors[1])

            # Set the output
            out_caches = itertools.chain(*out_caches)
@@ -1006,11 +1005,12 @@ def ln_lm_head(scribe):
                next_tok_id = scribe.s32[batch_size].Parameter(parameter_number=1)
            else:
                next_tok_id = scribe.s32[1].Parameter(parameter_number=1)
-           param_builder = DecoderParameterBuilder(scribe, 2)
+           start_ids = scribe.s32[batch_size].Parameter(parameter_number=2)
+           param_builder = DecoderParameterBuilder(scribe, 3)
            ln_f_weight, ln_f_bias, head_weight, head_bias = self._hlo_lm_head_params(param_builder)
-           gneration_params = self._hlo_generation_params(param_builder)
+           generation_params = self._hlo_generation_params(param_builder)
            logits = self.ln_lm_head_builder(hidden, next_tok_id, ln_f_weight, ln_f_bias, head_weight, head_bias, return_all_outputs=self.return_all_outputs)
-           output = self._hlo_generation(logits, gneration_params)
+           output = self._hlo_generation(logits, generation_params, start_ids=start_ids)
            if self.neuron_config.log_softmax_scores:
                logits, scores = self._hlo_post_layer(logits)
                outputs = [logits, scores]
@@ -1034,7 +1034,7 @@ def _hlo_generation_params(self, param_builder):
            params.append(param)
        return params

-   def _hlo_generation(self, logits, params, early_return=False, return_probs=False):
+   def _hlo_generation(self, logits, params, early_return=False, return_probs=False, start_ids=None):
        generation_config = self.neuron_config.on_device_generation
        if generation_config is None:
            return logits
@@ -1045,13 +1045,16 @@ def _hlo_generation(self, logits, params, early_return=False, return_probs=False
            self.neuron_config.on_device_generation.top_p = top_p
            self.neuron_config.on_device_generation.temperature = temperature
            self.neuron_config.on_device_generation.top_p_min_tokens = top_p_min_tokens
+
+       seq_ids = start_ids if self.neuron_config.continuous_batching is not None else None
        return generation.generate(
            logits,
            logits_indices,
            config=generation_config,
            tp_degree=self.tp_degree,
            early_return=early_return,
            return_probs=return_probs,
+           seq_ids=seq_ids,
        )

    # Mainly used for serialization purposes.
@@ -2458,10 +2461,12 @@ def setup(self, layers, pre_layer_params, ln_lm_head_params):

        hidden_buffers = list()
        last_token_id_buffers = list()
+       start_ids_buffers = list()
        for input_buffer in self.input_buffers:
-           hidden_buffer, _, _, last_token_id_buffer, *_ = input_buffer
+           hidden_buffer, _, start_ids_buffer, last_token_id_buffer, *_ = input_buffer
            hidden_buffers.append(hidden_buffer)
            last_token_id_buffers.append(last_token_id_buffer)
+           start_ids_buffers.append(start_ids_buffer)

        multi_layer_starts = range(0, len(layers), self.unroll)
        multi_layers = [layers[start:start+self.unroll] for start in multi_layer_starts]
@@ -2480,7 +2485,7 @@ def setup(self, layers, pre_layer_params, ln_lm_head_params):
        if self.neuron_config.is_valid_lm_head():
            for head_idx in range(0,len(self.ln_lm_head_kernels)):
                output_tensors = [*self.logits_buffer[head_idx]] if self.neuron_config.log_softmax_scores or self.neuron_config.is_eagle_target else [self.logits_buffer[head_idx]]
-               self.ln_lm_head_memories[head_idx].setup([hidden_buffers[head_idx], last_token_id_buffers[head_idx], *ln_lm_head_params], output_tensors)
+               self.ln_lm_head_memories[head_idx].setup([hidden_buffers[head_idx], last_token_id_buffers[head_idx], start_ids_buffers[head_idx], *ln_lm_head_params], output_tensors)
                self.ln_lm_head_kernels[head_idx].build()
                self.ln_lm_head_kernels[head_idx].load()
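For reference, the decoder.py changes thread the start_ids input through to the compiled lm-head program: the graph now declares start_ids as parameter 2, DecoderParameterBuilder starts numbering at 3, and setup() binds the matching start_ids buffer in the same position. A minimal sketch of that positional contract follows; the helper name make_lm_head_inputs is hypothetical, not a transformers_neuronx API.

# Sketch only: the compiled lm-head graph declares its parameters as
#   0: hidden, 1: next_tok_id, 2: start_ids, 3+: lm-head / generation params,
# so the host-side input list must match it position by position.
def make_lm_head_inputs(hidden_buffer, last_token_id_buffer, start_ids_buffer, ln_lm_head_params):
    # Mirrors the setup() call in the diff above: the start_ids buffer is
    # inserted between the last_token_id buffer and the remaining params.
    return [hidden_buffer, last_token_id_buffer, start_ids_buffer, *ln_lm_head_params]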

src/transformers_neuronx/layers/generation.py

Lines changed: 20 additions & 5 deletions
@@ -14,7 +14,7 @@
 # ==============================================================================
 from transformers_neuronx import hlo, config

-def generate(logits, logits_indices, config: config.GenerationConfig, tp_degree=1, early_return=False, return_probs=False):
+def generate(logits, logits_indices, config: config.GenerationConfig, tp_degree=1, early_return=False, return_probs=False, seq_ids=None):
    logits = mask_logits(logits, logits_indices, config.vocab_size)
    if not config.dynamic and not config.do_sample:
        tokens = greedy_search(logits, tp_degree=tp_degree)
@@ -62,10 +62,9 @@ def generate(logits, logits_indices, config: config.GenerationConfig, tp_degree=
        logits_slice = hlo.slice_along(logits, 0, start=batch_line, limit=batch_line+1)
        indices_slice = None if indices is None else hlo.slice_along(indices, 0, start=batch_line, limit=batch_line+1)

-       batch_line_top_k = config.top_k if hlo._is_hlo_scalar(config.top_k) else hlo.get_hlo_scalar_by_index(config.top_k, batch_line)
-       batch_line_top_p = config.top_p if hlo._is_hlo_scalar(config.top_p) else hlo.get_hlo_scalar_by_index(config.top_p, batch_line)
-       batch_line_temperature = config.temperature if hlo._is_hlo_scalar(config.temperature) else hlo.get_hlo_scalar_by_index(config.temperature, batch_line)
-       batch_line_top_p_min_tokens = config.top_p_min_tokens if hlo._is_hlo_scalar(config.top_p_min_tokens) else hlo.get_hlo_scalar_by_index(config.top_p_min_tokens, batch_line)
+       batch_line_top_k, batch_line_top_p, batch_line_temperature, batch_line_top_p_min_tokens = sampling_params_for_batch_line(
+           seq_ids, batch_line, config
+       )

        token = sample(
            logits_slice,
@@ -104,6 +103,22 @@ def generate(logits, logits_indices, config: config.GenerationConfig, tp_degree=
    returned_tokens = hlo.concatenate(tokens, dimension=0)
    return returned_tokens

+
+def sampling_params_for_batch_line(seq_ids, batch_line: int, config: config.GenerationConfig):
+    if seq_ids is not None:
+        seq_id_for_batch = hlo.slice_along(seq_ids, 0, start=batch_line, limit=batch_line+1)
+        batch_line_top_k = hlo.reshape(hlo.index_select(config.top_k, 0, seq_id_for_batch), [])
+        batch_line_top_p = hlo.reshape(hlo.index_select(config.top_p, 0, seq_id_for_batch), [])
+        batch_line_temperature = hlo.reshape(hlo.index_select(config.temperature, 0, seq_id_for_batch), [])
+        batch_line_top_p_min_tokens = hlo.reshape(hlo.index_select(config.top_p_min_tokens, 0, seq_id_for_batch), [])
+    else:
+        batch_line_top_k = config.top_k if hlo._is_hlo_scalar(config.top_k) else hlo.get_hlo_scalar_by_index(config.top_k, batch_line)
+        batch_line_top_p = config.top_p if hlo._is_hlo_scalar(config.top_p) else hlo.get_hlo_scalar_by_index(config.top_p, batch_line)
+        batch_line_temperature = config.temperature if hlo._is_hlo_scalar(config.temperature) else hlo.get_hlo_scalar_by_index(config.temperature, batch_line)
+        batch_line_top_p_min_tokens = config.top_p_min_tokens if hlo._is_hlo_scalar(config.top_p_min_tokens) else hlo.get_hlo_scalar_by_index(config.top_p_min_tokens, batch_line)
+    return (batch_line_top_k, batch_line_top_p, batch_line_temperature, batch_line_top_p_min_tokens)
+
+
 def mask_logits(logits, indices, model_vocab_size):
    vocab_size, n_active_tokens, _ = logits.sizes
    indices_br = hlo.broadcast(indices, (logits.sizes), broadcast_dimensions=(0,))
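To see why the seq_ids indirection matters, here is a plain-NumPy sketch of the lookup semantics the new sampling_params_for_batch_line helper implements on device (an illustration, not the HLO code; the toy shapes and values are assumptions). Under continuous batching, batch line i holds the sequence whose id is seq_ids[i], so per-sequence sampling tables must be indexed through seq_ids; the previous code always indexed them by batch-line position.

import numpy as np

def sampling_params_for_batch_line_np(seq_ids, batch_line, top_k, top_p, temperature, top_p_min_tokens):
    if seq_ids is not None:
        # Continuous batching: look up which sequence occupies this batch line.
        idx = seq_ids[batch_line]
    else:
        # Static batching: batch line and sequence id coincide, old behavior.
        idx = batch_line
    return top_k[idx], top_p[idx], temperature[idx], top_p_min_tokens[idx]

# Example: two active sequences scheduled onto batch lines in reverse order.
top_k = np.array([10, 50])           # per-sequence-id sampling params
top_p = np.array([0.9, 0.5])
temperature = np.array([1.0, 0.7])
top_p_min_tokens = np.array([1, 2])
seq_ids = np.array([1, 0])           # batch line 0 runs sequence 1, and vice versa

# Batch line 0 must use sequence 1's params (top_k=50, top_p=0.5, ...);
# indexing by batch line alone would wrongly return sequence 0's params.
print(sampling_params_for_batch_line_np(seq_ids, 0, top_k, top_p, temperature, top_p_min_tokens))
# -> (50, 0.5, 0.7, 2)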
