Skip to content

Commit dfa2d13

Browse files
ai-edge-bot and copybara-github
authored and committed
Move the logic of creating engine settings and session config to CreateEngineSettings and CreateSessionConfig.
LiteRT-LM-PiperOrigin-RevId: 816766300
1 parent f0ef599 commit dfa2d13

File tree

3 files changed

+153
-129
lines changed

3 files changed

+153
-129
lines changed

android/java/com/google/ai/edge/litertlm/Message.kt

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,8 @@
1515
*/
1616
package com.google.ai.edge.litertlm
1717

18-
import android.util.Base64
18+
import kotlin.io.encoding.Base64
19+
import kotlin.io.encoding.ExperimentalEncodingApi
1920
import org.json.JSONArray
2021
import org.json.JSONObject
2122

@@ -66,38 +67,32 @@ sealed class Content {
6667
}
6768

6869
/** Image provided as raw bytes. */
70+
@OptIn(ExperimentalEncodingApi::class)
6971
data class ImageBytes(val bytes: ByteArray) : Content() {
7072
override fun toJson(): JSONObject {
71-
return JSONObject()
72-
.put("type", "image")
73-
.put("blob", Base64.encodeToString(bytes, Base64.DEFAULT))
73+
return JSONObject().put("type", "image").put("blob", Base64.encode(bytes))
7474
}
7575
}
7676

7777
/** Image provided by a file. */
7878
data class ImageFile(val absolutePath: String) : Content() {
7979
override fun toJson(): JSONObject {
80-
return JSONObject()
81-
.put("type", "image")
82-
.put("path", absolutePath)
80+
return JSONObject().put("type", "image").put("path", absolutePath)
8381
}
8482
}
8583

8684
/** Audio provided as raw bytes. */
85+
@OptIn(ExperimentalEncodingApi::class)
8786
data class AudioBytes(val bytes: ByteArray) : Content() {
8887
override fun toJson(): JSONObject {
89-
return JSONObject()
90-
.put("type", "audio")
91-
.put("blob", Base64.encodeToString(bytes, Base64.DEFAULT))
88+
return JSONObject().put("type", "audio").put("blob", Base64.encode(bytes))
9289
}
9390
}
9491

9592
/** Audio provided by a file. */
9693
data class AudioFile(val absolutePath: String) : Content() {
9794
override fun toJson(): JSONObject {
98-
return JSONObject()
99-
.put("type", "audio")
100-
.put("path", absolutePath)
95+
return JSONObject().put("type", "audio").put("path", absolutePath)
10196
}
10297
}
10398
}

runtime/engine/litert_lm_lib.cc

Lines changed: 145 additions & 115 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,146 @@ const absl::Duration kWaitUntilDoneTimeout = absl::Minutes(10);
7777

7878
namespace {
7979

80+
// Helper to process the sampler backend string and return a sampler backend
81+
// if possible. Otherwise, return std::nullopt.
82+
std::optional<Backend> GetSamplerBackend(const LiteRtLmSettings& settings) {
83+
const std::string& sampler_backend_str = settings.sampler_backend;
84+
if (sampler_backend_str.empty()) {
85+
return std::nullopt;
86+
}
87+
const absl::StatusOr<Backend> sampler_backend =
88+
GetBackendFromString(sampler_backend_str);
89+
if (!sampler_backend.ok()) {
90+
ABSL_LOG(WARNING) << "Ignore invalid sampler backend string: "
91+
<< sampler_backend.status();
92+
return std::nullopt;
93+
}
94+
return *sampler_backend;
95+
}
96+
97+
// Creates the EngineSettings from the LiteRtLmSettings.
98+
absl::StatusOr<EngineSettings> CreateEngineSettings(
99+
const LiteRtLmSettings& settings) {
100+
const std::string model_path = settings.model_path;
101+
if (model_path.empty()) {
102+
return absl::InvalidArgumentError("Model path is empty.");
103+
}
104+
ABSL_LOG(INFO) << "Model path: " << model_path;
105+
ASSIGN_OR_RETURN(ModelAssets model_assets, // NOLINT
106+
ModelAssets::Create(model_path));
107+
auto backend_str = settings.backend;
108+
ABSL_LOG(INFO) << "Choose backend: " << backend_str;
109+
ASSIGN_OR_RETURN(Backend backend,
110+
litert::lm::GetBackendFromString(backend_str));
111+
std::optional<Backend> vision_backend = std::nullopt;
112+
if (settings.vision_backend.has_value()) {
113+
ABSL_LOG(INFO) << "Provided vision backend: " << *settings.vision_backend;
114+
ASSIGN_OR_RETURN(vision_backend, litert::lm::GetBackendFromString(
115+
*settings.vision_backend));
116+
}
117+
std::optional<Backend> audio_backend = std::nullopt;
118+
if (settings.audio_backend.has_value()) {
119+
ABSL_LOG(INFO) << "Provided audio backend: " << *settings.audio_backend;
120+
ASSIGN_OR_RETURN(audio_backend,
121+
litert::lm::GetBackendFromString(*settings.audio_backend));
122+
}
123+
124+
ASSIGN_OR_RETURN(
125+
EngineSettings engine_settings,
126+
EngineSettings::CreateDefault(std::move(model_assets), backend,
127+
vision_backend, audio_backend));
128+
if (settings.max_num_tokens > 0) {
129+
engine_settings.GetMutableMainExecutorSettings().SetMaxNumTokens(
130+
settings.max_num_tokens);
131+
}
132+
if (settings.force_f32) {
133+
engine_settings.GetMutableMainExecutorSettings().SetActivationDataType(
134+
litert::lm::ActivationDataType::FLOAT32);
135+
}
136+
if (settings.disable_cache) {
137+
engine_settings.GetMutableMainExecutorSettings().SetCacheDir(":nocache");
138+
}
139+
if (backend == Backend::CPU && settings.num_cpu_threads > 0) {
140+
auto& executor_settings = engine_settings.GetMutableMainExecutorSettings();
141+
ASSIGN_OR_RETURN(
142+
auto cpu_settings,
143+
executor_settings.MutableBackendConfig<litert::lm::CpuConfig>());
144+
cpu_settings.number_of_threads = settings.num_cpu_threads;
145+
executor_settings.SetBackendConfig(cpu_settings);
146+
}
147+
if (backend == Backend::GPU) {
148+
auto& executor_settings = engine_settings.GetMutableMainExecutorSettings();
149+
ASSIGN_OR_RETURN(
150+
auto gpu_settings,
151+
executor_settings.MutableBackendConfig<litert::lm::GpuConfig>());
152+
gpu_settings.external_tensor_mode = settings.gpu_external_tensor_mode;
153+
executor_settings.SetBackendConfig(gpu_settings);
154+
}
155+
const std::optional<Backend> sampler_backend = GetSamplerBackend(settings);
156+
if (sampler_backend.has_value()) {
157+
engine_settings.GetMutableMainExecutorSettings().SetSamplerBackend(
158+
*sampler_backend);
159+
}
160+
161+
AdvancedSettings advanced_settings{
162+
.prefill_batch_sizes = settings.prefill_batch_sizes,
163+
.num_output_candidates = settings.num_output_candidates,
164+
.configure_magic_numbers = settings.configure_magic_numbers,
165+
.verify_magic_numbers = settings.verify_magic_numbers,
166+
.clear_kv_cache_before_prefill = settings.clear_kv_cache_before_prefill,
167+
.num_logits_to_print_after_decode =
168+
static_cast<uint32_t>(settings.num_logits_to_print_after_decode),
169+
.gpu_madvise_original_shared_tensors =
170+
settings.gpu_madvise_original_shared_tensors,
171+
};
172+
if (advanced_settings != AdvancedSettings()) {
173+
engine_settings.GetMutableMainExecutorSettings().SetAdvancedSettings(
174+
advanced_settings);
175+
}
176+
177+
ABSL_LOG(INFO) << "executor_settings: "
178+
<< engine_settings.GetMainExecutorSettings();
179+
180+
if (engine_settings.GetVisionExecutorSettings().has_value()) {
181+
ABSL_LOG(INFO) << "vision_executor_settings: "
182+
<< engine_settings.GetVisionExecutorSettings().value();
183+
} else {
184+
ABSL_LOG(INFO) << "vision_executor_settings: not set";
185+
}
186+
if (engine_settings.GetAudioExecutorSettings().has_value()) {
187+
ABSL_LOG(INFO) << "audio_executor_settings: "
188+
<< engine_settings.GetAudioExecutorSettings().value();
189+
} else {
190+
ABSL_LOG(INFO) << "audio_executor_settings: not set";
191+
}
192+
193+
if (settings.benchmark) {
194+
if (settings.multi_turns) {
195+
ABSL_LOG(FATAL)
196+
<< "Benchmarking with multi-turns input is not supported.";
197+
}
198+
199+
litert::lm::proto::BenchmarkParams benchmark_params;
200+
benchmark_params.set_num_prefill_tokens(settings.benchmark_prefill_tokens);
201+
benchmark_params.set_num_decode_tokens(settings.benchmark_decode_tokens);
202+
engine_settings.GetMutableBenchmarkParams() = benchmark_params;
203+
}
204+
205+
return engine_settings;
206+
}
207+
208+
// Creates the SessionConfig from the LiteRtLmSettings.
209+
SessionConfig CreateSessionConfig(const LiteRtLmSettings& settings) {
210+
// Set the session config.
211+
auto session_config = litert::lm::SessionConfig::CreateDefault();
212+
session_config.SetNumOutputCandidates(settings.num_output_candidates);
213+
const std::optional<Backend> sampler_backend = GetSamplerBackend(settings);
214+
if (sampler_backend.has_value()) {
215+
session_config.SetSamplerBackend(*sampler_backend);
216+
}
217+
return session_config;
218+
}
219+
80220
absl::Status PrintJsonMessage(const JsonMessage& message,
81221
std::stringstream& captured_output,
82222
bool streaming = false) {
@@ -282,10 +422,6 @@ void RunScoreText(litert::lm::Engine* llm, litert::lm::Engine::Session* session,
282422
} // namespace
283423

284424
absl::Status RunLiteRtLm(const LiteRtLmSettings& settings) {
285-
const std::string model_path = settings.model_path;
286-
if (model_path.empty()) {
287-
return absl::InvalidArgumentError("Model path is empty.");
288-
}
289425

290426
std::unique_ptr<tflite::profiling::memory::MemoryUsageMonitor> mem_monitor;
291427
if (settings.report_peak_memory_footprint) {
@@ -294,122 +430,16 @@ absl::Status RunLiteRtLm(const LiteRtLmSettings& settings) {
294430
kMemoryCheckIntervalMs);
295431
mem_monitor->Start();
296432
}
297-
ABSL_LOG(INFO) << "Model path: " << model_path;
298-
ASSIGN_OR_RETURN(ModelAssets model_assets, // NOLINT
299-
ModelAssets::Create(model_path));
300-
auto backend_str = settings.backend;
301-
ABSL_LOG(INFO) << "Choose backend: " << backend_str;
302-
ASSIGN_OR_RETURN(Backend backend,
303-
litert::lm::GetBackendFromString(backend_str));
304-
std::optional<Backend> vision_backend = std::nullopt;
305-
if (settings.vision_backend.has_value()) {
306-
ABSL_LOG(INFO) << "Provided vision backend: " << *settings.vision_backend;
307-
ASSIGN_OR_RETURN(vision_backend, litert::lm::GetBackendFromString(
308-
*settings.vision_backend));
309-
}
310-
std::optional<Backend> audio_backend = std::nullopt;
311-
if (settings.audio_backend.has_value()) {
312-
ABSL_LOG(INFO) << "Provided audio backend: " << *settings.audio_backend;
313-
ASSIGN_OR_RETURN(audio_backend,
314-
litert::lm::GetBackendFromString(*settings.audio_backend));
315-
}
316-
317-
ASSIGN_OR_RETURN(
318-
EngineSettings engine_settings,
319-
EngineSettings::CreateDefault(std::move(model_assets), backend,
320-
vision_backend, audio_backend));
321-
if (settings.max_num_tokens > 0) {
322-
engine_settings.GetMutableMainExecutorSettings().SetMaxNumTokens(
323-
settings.max_num_tokens);
324-
}
325-
if (settings.force_f32) {
326-
engine_settings.GetMutableMainExecutorSettings().SetActivationDataType(
327-
litert::lm::ActivationDataType::FLOAT32);
328-
}
329-
if (settings.disable_cache) {
330-
engine_settings.GetMutableMainExecutorSettings().SetCacheDir(":nocache");
331-
}
332-
if (backend == Backend::CPU && settings.num_cpu_threads > 0) {
333-
auto& executor_settings = engine_settings.GetMutableMainExecutorSettings();
334-
ASSIGN_OR_RETURN(
335-
auto cpu_settings,
336-
executor_settings.MutableBackendConfig<litert::lm::CpuConfig>());
337-
cpu_settings.number_of_threads = settings.num_cpu_threads;
338-
executor_settings.SetBackendConfig(cpu_settings);
339-
}
340-
if (backend == Backend::GPU) {
341-
auto& executor_settings = engine_settings.GetMutableMainExecutorSettings();
342-
ASSIGN_OR_RETURN(
343-
auto gpu_settings,
344-
executor_settings.MutableBackendConfig<litert::lm::GpuConfig>());
345-
gpu_settings.external_tensor_mode = settings.gpu_external_tensor_mode;
346-
executor_settings.SetBackendConfig(gpu_settings);
347-
}
348-
auto session_config = litert::lm::SessionConfig::CreateDefault();
349-
session_config.SetNumOutputCandidates(settings.num_output_candidates);
350-
auto sampler_backend_str = settings.sampler_backend;
351-
if (!sampler_backend_str.empty()) {
352-
auto sampler_backend =
353-
litert::lm::GetBackendFromString(settings.sampler_backend);
354-
if (!sampler_backend.ok()) {
355-
ABSL_LOG(WARNING) << "Ignore invalid sampler backend string: "
356-
<< sampler_backend.status();
357-
} else {
358-
session_config.SetSamplerBackend(*sampler_backend);
359-
auto& executor_settings =
360-
engine_settings.GetMutableMainExecutorSettings();
361-
executor_settings.SetSamplerBackend(*sampler_backend);
362-
}
363-
}
364-
365-
AdvancedSettings advanced_settings{
366-
.prefill_batch_sizes = settings.prefill_batch_sizes,
367-
.num_output_candidates = session_config.GetNumOutputCandidates(),
368-
.configure_magic_numbers = settings.configure_magic_numbers,
369-
.verify_magic_numbers = settings.verify_magic_numbers,
370-
.clear_kv_cache_before_prefill = settings.clear_kv_cache_before_prefill,
371-
.num_logits_to_print_after_decode =
372-
static_cast<uint32_t>(settings.num_logits_to_print_after_decode),
373-
.gpu_madvise_original_shared_tensors =
374-
settings.gpu_madvise_original_shared_tensors,
375-
};
376-
if (advanced_settings != AdvancedSettings()) {
377-
engine_settings.GetMutableMainExecutorSettings().SetAdvancedSettings(
378-
advanced_settings);
379-
}
380-
381-
ABSL_LOG(INFO) << "executor_settings: "
382-
<< engine_settings.GetMainExecutorSettings();
383-
384-
if (engine_settings.GetVisionExecutorSettings().has_value()) {
385-
ABSL_LOG(INFO) << "vision_executor_settings: "
386-
<< engine_settings.GetVisionExecutorSettings().value();
387-
} else {
388-
ABSL_LOG(INFO) << "vision_executor_settings: not set";
389-
}
390-
if (engine_settings.GetAudioExecutorSettings().has_value()) {
391-
ABSL_LOG(INFO) << "audio_executor_settings: "
392-
<< engine_settings.GetAudioExecutorSettings().value();
393-
} else {
394-
ABSL_LOG(INFO) << "audio_executor_settings: not set";
395-
}
396-
397-
if (settings.benchmark) {
398-
if (settings.multi_turns) {
399-
ABSL_LOG(FATAL)
400-
<< "Benchmarking with multi-turns input is not supported.";
401-
}
402-
403-
litert::lm::proto::BenchmarkParams benchmark_params;
404-
benchmark_params.set_num_prefill_tokens(settings.benchmark_prefill_tokens);
405-
benchmark_params.set_num_decode_tokens(settings.benchmark_decode_tokens);
406-
engine_settings.GetMutableBenchmarkParams() = benchmark_params;
407-
}
408433

434+
// Get the engine settings and create the engine.
435+
ASSIGN_OR_RETURN(EngineSettings engine_settings,
436+
CreateEngineSettings(settings));
409437
ABSL_LOG(INFO) << "Creating engine";
410438
ASSIGN_OR_RETURN(auto engine,
411439
litert::lm::Engine::CreateEngine(std::move(engine_settings),
412440
settings.input_prompt));
441+
// Get the session config.
442+
const SessionConfig session_config = CreateSessionConfig(settings);
413443

414444
// Session and Conversation are mutually exclusive. Only when
415445
// settings.score_target_text is set, we will create a Session to run the

runtime/engine/litert_lm_lib.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@
1818
#include <optional>
1919
#include <set>
2020
#include <string>
21-
#include <vector>
2221

2322
#include "absl/status/status.h" // from @com_google_absl
2423

0 commit comments

Comments
 (0)