Commit dfad938

protobird-git authored and copybara-github committed
Clean up LlmLiteRtCompiledModelExecutor.

Remove decoded_logits, which is already kept in decode_output_buffers_["logits"].

PiperOrigin-RevId: 820730190
1 parent 398f019 commit dfad938
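
A way to read the rationale: the logits already live in the executor's decode_output_buffers_ map, so the decoded_logits_ member was a second handle to storage the class owns anyway. Below is a minimal, self-contained sketch of that ownership pattern; BufferHandle and Executor are simplified stand-ins for illustration, not the real LiteRT types.

#include <cstdio>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Simplified stand-in for ::litert::TensorBuffer: a cheap handle to
// storage owned elsewhere (here, by the executor's output-buffer map).
struct BufferHandle {
  std::shared_ptr<std::vector<float>> storage;
};

class Executor {
 public:
  Executor() {
    outputs_["logits"] = BufferHandle{std::make_shared<std::vector<float>>(4)};
  }

  // Mirrors DecodeLogits(): returns a handle to the "logits" entry the
  // executor already keeps in its output map.
  BufferHandle DecodeLogits() { return outputs_.at("logits"); }

  void Decode() {
    // A local handle is enough: the storage stays alive in outputs_, so a
    // decoded_logits_-style member would only duplicate this handle.
    BufferHandle decoded_logits = DecodeLogits();
    std::printf("logits elements: %zu\n", decoded_logits.storage->size());
  }

 private:
  std::map<std::string, BufferHandle> outputs_;
};

int main() {
  Executor e;
  e.Decode();
  return 0;
}

With the member gone, the handle's lifetime matches the one Decode() call that needs it.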

2 files changed: +37 -41 lines


runtime/executor/llm_litert_compiled_model_executor.cc

Lines changed: 37 additions & 37 deletions
@@ -549,49 +549,49 @@ absl::Status LlmLiteRtCompiledModelExecutor::ConsumePendingOrAddProcessedToken(
 absl::Status LlmLiteRtCompiledModelExecutor::DecodeInternal(
     const int step, const std::shared_ptr<TokenData> token,
     TensorBuffer& output_logits) {
-  {
-    const bool use_token_as_lookup = !signatures_.input_tokens.empty();
-    const bool use_per_layer_embedding =
-        signatures_.input_per_layer_embeddings.has_value();
-
-    // Fill the input buffers with scoped locks.
-    if (use_token_as_lookup) {
-      RETURN_IF_ERROR(FillInputBufferWithToken(
-          token, decode_input_buffers_[signatures_.input_tokens]));
-    } else {
-      if (!signatures_.input_embeddings.has_value()) {
-        return absl::InvalidArgumentError(
-            "Input tokens or embeddings must be provided.");
-      }
+  const bool use_token_as_lookup = !signatures_.input_tokens.empty();
+  const bool use_per_layer_embedding =
+      signatures_.input_per_layer_embeddings.has_value();
+
+  // Fill the input buffers with scoped locks.
+  if (use_token_as_lookup) {
+    RETURN_IF_ERROR(FillInputBufferWithToken(
+        token, decode_input_buffers_[signatures_.input_tokens]));
+  } else {
+    if (!signatures_.input_embeddings.has_value()) {
+      return absl::InvalidArgumentError(
+          "Input tokens or embeddings must be provided.");
+    }
+    RETURN_IF_ERROR(FillInputBufferWithToken(
+        token, decode_input_buffers_[signatures_.input_embeddings.value()]));
+    if (use_per_layer_embedding) {
       RETURN_IF_ERROR(FillInputBufferWithToken(
-          token, decode_input_buffers_[signatures_.input_embeddings.value()]));
-      if (use_per_layer_embedding) {
-        RETURN_IF_ERROR(FillInputBufferWithToken(
-            token,
-            decode_input_buffers_[signatures_.input_per_layer_embeddings
-                .value()],
-            /*is_per_layer_embedding=*/true));
-      }
+          token,
+          decode_input_buffers_[signatures_.input_per_layer_embeddings.value()],
+          /*is_per_layer_embedding=*/true));
     }
-    auto& decode_input_pos_buffer =
-        decode_input_buffers_[signatures_.input_positions];
+  }
+
+  {
     LITERT_ASSIGN_OR_RETURN(
         auto decode_input_pos_lock_and_addr,
-        TensorBufferScopedLock::Create(decode_input_pos_buffer,
-                                       TensorBuffer::LockMode::kWrite));
+        TensorBufferScopedLock::Create(
+            decode_input_buffers_[signatures_.input_positions],
+            TensorBuffer::LockMode::kWrite));
     auto* decode_input_pos_ptr =
         static_cast<int32_t*>(decode_input_pos_lock_and_addr.second);
-    if (signatures_.input_attn_mask.has_value()) {
-      RETURN_IF_ERROR(InitializeAttentionMask(
-          decode_input_buffers_[signatures_.input_attn_mask.value()],
-          IsCalculationPrecisionF16()));
-      RETURN_IF_ERROR(FillAttentionMask(
-          decode_input_buffers_[signatures_.input_attn_mask.value()], step,
-          /*steps=*/1));
-    }
     decode_input_pos_ptr[0] = step;
   }
 
+  if (signatures_.input_attn_mask.has_value()) {
+    RETURN_IF_ERROR(InitializeAttentionMask(
+        decode_input_buffers_[signatures_.input_attn_mask.value()],
+        IsCalculationPrecisionF16()));
+    RETURN_IF_ERROR(FillAttentionMask(
+        decode_input_buffers_[signatures_.input_attn_mask.value()], step,
+        /*steps=*/1));
+  }
+
   absl::flat_hash_map<absl::string_view, ::litert::TensorBuffer>
       decode_input_buffers;
   for (const auto& [input_name, input_buffer] : decode_input_buffers_) {
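
The hunk above is mostly a scoping change: TensorBufferScopedLock is RAII (it locks the buffer for writing when created and releases it when it leaves scope), so the rewrite holds the position buffer's write lock only for the single int32 store and moves the attention-mask setup outside the lock's block. A generic analogue of that critical-section narrowing, with std::mutex standing in for the LiteRT buffer lock:

#include <cstdint>
#include <mutex>

std::mutex position_mutex;  // stands in for the position buffer's write lock
int32_t position_slot = 0;  // stands in for the mapped buffer memory

// Placeholder for work that needs no lock, like the attention-mask fill.
void FillAttentionMaskNoLock(int step) { (void)step; }

void PrepareDecodeInputs(int step) {
  {
    std::scoped_lock lock(position_mutex);  // acquire, store, release
    position_slot = step;
  }
  // Unrelated setup runs after the lock is gone, as the rewritten
  // DecodeInternal() now does with the attention mask.
  FillAttentionMaskNoLock(step);
}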
@@ -633,13 +633,13 @@ absl::Status LlmLiteRtCompiledModelExecutor::Decode(
     ::litert::TensorBuffer& output_tokens,
     const ExecutorDecodeParams& decode_params) {
 
-  ASSIGN_OR_RETURN(decoded_logits_,
+  ASSIGN_OR_RETURN(auto decoded_logits,
                    DecodeLogits(ExecutorInputs(), decode_params));
-  LITERT_ASSIGN_OR_RETURN(auto size, decoded_logits_.PackedSize());
+  LITERT_ASSIGN_OR_RETURN(auto size, decoded_logits.PackedSize());
   if (decoded_logits_vector_.empty()) {
     decoded_logits_vector_ = std::vector<float>(size / sizeof(float));
   }
-  RETURN_IF_ERROR(SampleLogits(decoded_logits_, output_tokens));
+  RETURN_IF_ERROR(SampleLogits(decoded_logits, output_tokens));
 
   // Read the first output token for the next input token id.
   bool reset_output_token = false;
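
One detail the hunk above keeps: decoded_logits_vector_ survives as a reusable scratch buffer, allocated once from the byte count PackedSize() reports. The sizing arithmetic as a sketch (MakeLogitsScratch is a hypothetical helper, not part of the class):

#include <cstddef>
#include <vector>

// A packed byte count converts to a float element count by dividing by
// sizeof(float); allocating once and reusing mirrors decoded_logits_vector_.
std::vector<float> MakeLogitsScratch(std::size_t packed_size_bytes) {
  return std::vector<float>(packed_size_bytes / sizeof(float));
}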

runtime/executor/llm_litert_compiled_model_executor.h

Lines changed: 0 additions & 4 deletions
@@ -240,10 +240,6 @@ class LlmLiteRtCompiledModelExecutor : public LlmExecutor {
   // track of the pending input token, if any.
   ProcessedTokens processed_tokens_;
 
-  // A tensor buffer to store the logits decoded before sampling the final
-  // tokens. It's to avoid creating a new tensor buffer for each Decode() call.
-  ::litert::TensorBuffer decoded_logits_;
-
   // A vector to store the logits decoded before sampling the final tokens.
   // It's to avoid creating a new vector for each Decode() call.
   std::vector<float> decoded_logits_vector_;
