@@ -77,6 +77,146 @@ const absl::Duration kWaitUntilDoneTimeout = absl::Minutes(10);
7777
7878namespace {
7979
80+ // Helper to process the sampler backend string and return a sampler backend
81+ // if possible. Otherwise, return std::nullopt.
82+ std::optional<Backend> GetSamplerBackend (const LiteRtLmSettings& settings) {
83+ const std::string& sampler_backend_str = settings.sampler_backend ;
84+ if (sampler_backend_str.empty ()) {
85+ return std::nullopt ;
86+ }
87+ const absl::StatusOr<Backend> sampler_backend =
88+ GetBackendFromString (sampler_backend_str);
89+ if (!sampler_backend.ok ()) {
90+ ABSL_LOG (WARNING) << " Ignore invalid sampler backend string: "
91+ << sampler_backend.status ();
92+ return std::nullopt ;
93+ }
94+ return *sampler_backend;
95+ }
96+
97+ // Creates the EngineSettings from the LiteRtLmSettings.
98+ absl::StatusOr<EngineSettings> CreateEngineSettings (
99+ const LiteRtLmSettings& settings) {
100+ const std::string model_path = settings.model_path ;
101+ if (model_path.empty ()) {
102+ return absl::InvalidArgumentError (" Model path is empty." );
103+ }
104+ ABSL_LOG (INFO) << " Model path: " << model_path;
105+ ASSIGN_OR_RETURN (ModelAssets model_assets, // NOLINT
106+ ModelAssets::Create (model_path));
107+ auto backend_str = settings.backend ;
108+ ABSL_LOG (INFO) << " Choose backend: " << backend_str;
109+ ASSIGN_OR_RETURN (Backend backend,
110+ litert::lm::GetBackendFromString (backend_str));
111+ std::optional<Backend> vision_backend = std::nullopt ;
112+ if (settings.vision_backend .has_value ()) {
113+ ABSL_LOG (INFO) << " Provided vision backend: " << *settings.vision_backend ;
114+ ASSIGN_OR_RETURN (vision_backend, litert::lm::GetBackendFromString (
115+ *settings.vision_backend ));
116+ }
117+ std::optional<Backend> audio_backend = std::nullopt ;
118+ if (settings.audio_backend .has_value ()) {
119+ ABSL_LOG (INFO) << " Provided audio backend: " << *settings.audio_backend ;
120+ ASSIGN_OR_RETURN (audio_backend,
121+ litert::lm::GetBackendFromString (*settings.audio_backend ));
122+ }
123+
124+ ASSIGN_OR_RETURN (
125+ EngineSettings engine_settings,
126+ EngineSettings::CreateDefault (std::move (model_assets), backend,
127+ vision_backend, audio_backend));
128+ if (settings.max_num_tokens > 0 ) {
129+ engine_settings.GetMutableMainExecutorSettings ().SetMaxNumTokens (
130+ settings.max_num_tokens );
131+ }
132+ if (settings.force_f32 ) {
133+ engine_settings.GetMutableMainExecutorSettings ().SetActivationDataType (
134+ litert::lm::ActivationDataType::FLOAT32);
135+ }
136+ if (settings.disable_cache ) {
137+ engine_settings.GetMutableMainExecutorSettings ().SetCacheDir (" :nocache" );
138+ }
139+ if (backend == Backend::CPU && settings.num_cpu_threads > 0 ) {
140+ auto & executor_settings = engine_settings.GetMutableMainExecutorSettings ();
141+ ASSIGN_OR_RETURN (
142+ auto cpu_settings,
143+ executor_settings.MutableBackendConfig <litert::lm::CpuConfig>());
144+ cpu_settings.number_of_threads = settings.num_cpu_threads ;
145+ executor_settings.SetBackendConfig (cpu_settings);
146+ }
147+ if (backend == Backend::GPU) {
148+ auto & executor_settings = engine_settings.GetMutableMainExecutorSettings ();
149+ ASSIGN_OR_RETURN (
150+ auto gpu_settings,
151+ executor_settings.MutableBackendConfig <litert::lm::GpuConfig>());
152+ gpu_settings.external_tensor_mode = settings.gpu_external_tensor_mode ;
153+ executor_settings.SetBackendConfig (gpu_settings);
154+ }
155+ const std::optional<Backend> sampler_backend = GetSamplerBackend (settings);
156+ if (sampler_backend.has_value ()) {
157+ engine_settings.GetMutableMainExecutorSettings ().SetSamplerBackend (
158+ *sampler_backend);
159+ }
160+
161+ AdvancedSettings advanced_settings{
162+ .prefill_batch_sizes = settings.prefill_batch_sizes ,
163+ .num_output_candidates = settings.num_output_candidates ,
164+ .configure_magic_numbers = settings.configure_magic_numbers ,
165+ .verify_magic_numbers = settings.verify_magic_numbers ,
166+ .clear_kv_cache_before_prefill = settings.clear_kv_cache_before_prefill ,
167+ .num_logits_to_print_after_decode =
168+ static_cast <uint32_t >(settings.num_logits_to_print_after_decode ),
169+ .gpu_madvise_original_shared_tensors =
170+ settings.gpu_madvise_original_shared_tensors ,
171+ };
172+ if (advanced_settings != AdvancedSettings ()) {
173+ engine_settings.GetMutableMainExecutorSettings ().SetAdvancedSettings (
174+ advanced_settings);
175+ }
176+
177+ ABSL_LOG (INFO) << " executor_settings: "
178+ << engine_settings.GetMainExecutorSettings ();
179+
180+ if (engine_settings.GetVisionExecutorSettings ().has_value ()) {
181+ ABSL_LOG (INFO) << " vision_executor_settings: "
182+ << engine_settings.GetVisionExecutorSettings ().value ();
183+ } else {
184+ ABSL_LOG (INFO) << " vision_executor_settings: not set" ;
185+ }
186+ if (engine_settings.GetAudioExecutorSettings ().has_value ()) {
187+ ABSL_LOG (INFO) << " audio_executor_settings: "
188+ << engine_settings.GetAudioExecutorSettings ().value ();
189+ } else {
190+ ABSL_LOG (INFO) << " audio_executor_settings: not set" ;
191+ }
192+
193+ if (settings.benchmark ) {
194+ if (settings.multi_turns ) {
195+ ABSL_LOG (FATAL)
196+ << " Benchmarking with multi-turns input is not supported." ;
197+ }
198+
199+ litert::lm::proto::BenchmarkParams benchmark_params;
200+ benchmark_params.set_num_prefill_tokens (settings.benchmark_prefill_tokens );
201+ benchmark_params.set_num_decode_tokens (settings.benchmark_decode_tokens );
202+ engine_settings.GetMutableBenchmarkParams () = benchmark_params;
203+ }
204+
205+ return engine_settings;
206+ }
207+
208+ // Creates the SessionConfig from the LiteRtLmSettings.
209+ SessionConfig CreateSessionConfig (const LiteRtLmSettings& settings) {
210+ // Set the session config.
211+ auto session_config = litert::lm::SessionConfig::CreateDefault ();
212+ session_config.SetNumOutputCandidates (settings.num_output_candidates );
213+ const std::optional<Backend> sampler_backend = GetSamplerBackend (settings);
214+ if (sampler_backend.has_value ()) {
215+ session_config.SetSamplerBackend (*sampler_backend);
216+ }
217+ return session_config;
218+ }
219+
80220absl::Status PrintJsonMessage (const JsonMessage& message,
81221 std::stringstream& captured_output,
82222 bool streaming = false ) {
@@ -282,10 +422,6 @@ void RunScoreText(litert::lm::Engine* llm, litert::lm::Engine::Session* session,
282422} // namespace
283423
284424absl::Status RunLiteRtLm (const LiteRtLmSettings& settings) {
285- const std::string model_path = settings.model_path ;
286- if (model_path.empty ()) {
287- return absl::InvalidArgumentError (" Model path is empty." );
288- }
289425
290426 std::unique_ptr<tflite::profiling::memory::MemoryUsageMonitor> mem_monitor;
291427 if (settings.report_peak_memory_footprint ) {
@@ -294,122 +430,16 @@ absl::Status RunLiteRtLm(const LiteRtLmSettings& settings) {
294430 kMemoryCheckIntervalMs );
295431 mem_monitor->Start ();
296432 }
297- ABSL_LOG (INFO) << " Model path: " << model_path;
298- ASSIGN_OR_RETURN (ModelAssets model_assets, // NOLINT
299- ModelAssets::Create (model_path));
300- auto backend_str = settings.backend ;
301- ABSL_LOG (INFO) << " Choose backend: " << backend_str;
302- ASSIGN_OR_RETURN (Backend backend,
303- litert::lm::GetBackendFromString (backend_str));
304- std::optional<Backend> vision_backend = std::nullopt ;
305- if (settings.vision_backend .has_value ()) {
306- ABSL_LOG (INFO) << " Provided vision backend: " << *settings.vision_backend ;
307- ASSIGN_OR_RETURN (vision_backend, litert::lm::GetBackendFromString (
308- *settings.vision_backend ));
309- }
310- std::optional<Backend> audio_backend = std::nullopt ;
311- if (settings.audio_backend .has_value ()) {
312- ABSL_LOG (INFO) << " Provided audio backend: " << *settings.audio_backend ;
313- ASSIGN_OR_RETURN (audio_backend,
314- litert::lm::GetBackendFromString (*settings.audio_backend ));
315- }
316-
317- ASSIGN_OR_RETURN (
318- EngineSettings engine_settings,
319- EngineSettings::CreateDefault (std::move (model_assets), backend,
320- vision_backend, audio_backend));
321- if (settings.max_num_tokens > 0 ) {
322- engine_settings.GetMutableMainExecutorSettings ().SetMaxNumTokens (
323- settings.max_num_tokens );
324- }
325- if (settings.force_f32 ) {
326- engine_settings.GetMutableMainExecutorSettings ().SetActivationDataType (
327- litert::lm::ActivationDataType::FLOAT32);
328- }
329- if (settings.disable_cache ) {
330- engine_settings.GetMutableMainExecutorSettings ().SetCacheDir (" :nocache" );
331- }
332- if (backend == Backend::CPU && settings.num_cpu_threads > 0 ) {
333- auto & executor_settings = engine_settings.GetMutableMainExecutorSettings ();
334- ASSIGN_OR_RETURN (
335- auto cpu_settings,
336- executor_settings.MutableBackendConfig <litert::lm::CpuConfig>());
337- cpu_settings.number_of_threads = settings.num_cpu_threads ;
338- executor_settings.SetBackendConfig (cpu_settings);
339- }
340- if (backend == Backend::GPU) {
341- auto & executor_settings = engine_settings.GetMutableMainExecutorSettings ();
342- ASSIGN_OR_RETURN (
343- auto gpu_settings,
344- executor_settings.MutableBackendConfig <litert::lm::GpuConfig>());
345- gpu_settings.external_tensor_mode = settings.gpu_external_tensor_mode ;
346- executor_settings.SetBackendConfig (gpu_settings);
347- }
348- auto session_config = litert::lm::SessionConfig::CreateDefault ();
349- session_config.SetNumOutputCandidates (settings.num_output_candidates );
350- auto sampler_backend_str = settings.sampler_backend ;
351- if (!sampler_backend_str.empty ()) {
352- auto sampler_backend =
353- litert::lm::GetBackendFromString (settings.sampler_backend );
354- if (!sampler_backend.ok ()) {
355- ABSL_LOG (WARNING) << " Ignore invalid sampler backend string: "
356- << sampler_backend.status ();
357- } else {
358- session_config.SetSamplerBackend (*sampler_backend);
359- auto & executor_settings =
360- engine_settings.GetMutableMainExecutorSettings ();
361- executor_settings.SetSamplerBackend (*sampler_backend);
362- }
363- }
364-
365- AdvancedSettings advanced_settings{
366- .prefill_batch_sizes = settings.prefill_batch_sizes ,
367- .num_output_candidates = session_config.GetNumOutputCandidates (),
368- .configure_magic_numbers = settings.configure_magic_numbers ,
369- .verify_magic_numbers = settings.verify_magic_numbers ,
370- .clear_kv_cache_before_prefill = settings.clear_kv_cache_before_prefill ,
371- .num_logits_to_print_after_decode =
372- static_cast <uint32_t >(settings.num_logits_to_print_after_decode ),
373- .gpu_madvise_original_shared_tensors =
374- settings.gpu_madvise_original_shared_tensors ,
375- };
376- if (advanced_settings != AdvancedSettings ()) {
377- engine_settings.GetMutableMainExecutorSettings ().SetAdvancedSettings (
378- advanced_settings);
379- }
380-
381- ABSL_LOG (INFO) << " executor_settings: "
382- << engine_settings.GetMainExecutorSettings ();
383-
384- if (engine_settings.GetVisionExecutorSettings ().has_value ()) {
385- ABSL_LOG (INFO) << " vision_executor_settings: "
386- << engine_settings.GetVisionExecutorSettings ().value ();
387- } else {
388- ABSL_LOG (INFO) << " vision_executor_settings: not set" ;
389- }
390- if (engine_settings.GetAudioExecutorSettings ().has_value ()) {
391- ABSL_LOG (INFO) << " audio_executor_settings: "
392- << engine_settings.GetAudioExecutorSettings ().value ();
393- } else {
394- ABSL_LOG (INFO) << " audio_executor_settings: not set" ;
395- }
396-
397- if (settings.benchmark ) {
398- if (settings.multi_turns ) {
399- ABSL_LOG (FATAL)
400- << " Benchmarking with multi-turns input is not supported." ;
401- }
402-
403- litert::lm::proto::BenchmarkParams benchmark_params;
404- benchmark_params.set_num_prefill_tokens (settings.benchmark_prefill_tokens );
405- benchmark_params.set_num_decode_tokens (settings.benchmark_decode_tokens );
406- engine_settings.GetMutableBenchmarkParams () = benchmark_params;
407- }
408433
434+ // Get the engine settings and create the engine.
435+ ASSIGN_OR_RETURN (EngineSettings engine_settings,
436+ CreateEngineSettings (settings));
409437 ABSL_LOG (INFO) << " Creating engine" ;
410438 ASSIGN_OR_RETURN (auto engine,
411439 litert::lm::Engine::CreateEngine (std::move (engine_settings),
412440 settings.input_prompt ));
441+ // Get the session config.
442+ const SessionConfig session_config = CreateSessionConfig (settings);
413443
414444 // Session and Conversation are mutually exclusive. Only when
415445 // settings.score_target_text is set, we will create a Session to run the
0 commit comments