11//! This is a translation of embedding.cpp in llama.cpp using llama-cpp-2.
22#![ allow(
3- clippy:: cast_possible_wrap,
4- clippy:: cast_possible_truncation,
5- clippy:: cast_precision_loss,
6- clippy:: cast_sign_loss
3+ clippy:: cast_possible_wrap,
4+ clippy:: cast_possible_truncation,
5+ clippy:: cast_precision_loss,
6+ clippy:: cast_sign_loss
77) ]
88
99use std:: io:: Write ;
1010use std:: path:: PathBuf ;
11- use std:: str:: FromStr ;
1211use std:: time:: Duration ;
1312
1413use anyhow:: { bail, Context , Result } ;
1514use clap:: Parser ;
1615use hf_hub:: api:: sync:: ApiBuilder ;
1716
18- use llama_cpp_2:: context:: LlamaContext ;
1917use llama_cpp_2:: context:: params:: LlamaContextParams ;
18+ use llama_cpp_2:: context:: LlamaContext ;
2019use llama_cpp_2:: ggml_time_us;
2120use llama_cpp_2:: llama_backend:: LlamaBackend ;
2221use llama_cpp_2:: llama_batch:: LlamaBatch ;
22+ use llama_cpp_2:: model:: params:: LlamaModelParams ;
2323use llama_cpp_2:: model:: AddBos ;
2424use llama_cpp_2:: model:: LlamaModel ;
25- use llama_cpp_2:: model:: params:: LlamaModelParams ;
2625
2726#[ derive( clap:: Parser , Debug , Clone ) ]
2827struct Args {
@@ -41,7 +40,6 @@ struct Args {
4140 disable_gpu : bool ,
4241}
4342
44-
4543#[ derive( clap:: Subcommand , Debug , Clone ) ]
4644enum Model {
4745 /// Use an already downloaded model
@@ -119,7 +117,8 @@ fn main() -> Result<()> {
119117 let prompt_lines = prompt. lines ( ) ;
120118
121119 // tokenize the prompt
122- let tokens_lines_list = prompt_lines. map ( |line| model. str_to_token ( & line, AddBos :: Always ) )
120+ let tokens_lines_list = prompt_lines
121+ . map ( |line| model. str_to_token ( line, AddBos :: Always ) )
123122 . collect :: < Result < Vec < _ > , _ > > ( )
124123 . with_context ( || format ! ( "failed to tokenize {prompt}" ) ) ?;
125124
@@ -140,7 +139,7 @@ fn main() -> Result<()> {
140139 for token in token_line {
141140 eprintln ! ( " {} --> {}" , token, model. token_to_str( * token) ?) ;
142141 }
143- eprintln ! ( )
142+ eprintln ! ( ) ;
144143 }
145144
146145 std:: io:: stderr ( ) . flush ( ) ?;
@@ -157,15 +156,27 @@ fn main() -> Result<()> {
157156 for tokens in & tokens_lines_list {
158157 // Flush the batch if the next prompt would exceed our batch size
159158 if ( batch. n_tokens ( ) as usize + tokens. len ( ) ) > n_ctx {
160- batch_decode ( & mut ctx, & mut batch, max_seq_id_batch, & mut output, normalise) ?;
159+ batch_decode (
160+ & mut ctx,
161+ & mut batch,
162+ max_seq_id_batch,
163+ & mut output,
164+ normalise,
165+ ) ?;
161166 max_seq_id_batch = 0 ;
162167 }
163168
164- batch. add_sequence ( & tokens, max_seq_id_batch, false ) ?;
169+ batch. add_sequence ( tokens, max_seq_id_batch, false ) ?;
165170 max_seq_id_batch += 1 ;
166171 }
167172 // Handle final batch
168- batch_decode ( & mut ctx, & mut batch, max_seq_id_batch, & mut output, normalise) ?;
173+ batch_decode (
174+ & mut ctx,
175+ & mut batch,
176+ max_seq_id_batch,
177+ & mut output,
178+ normalise,
179+ ) ?;
169180
170181 let t_main_end = ggml_time_us ( ) ;
171182
@@ -175,7 +186,7 @@ fn main() -> Result<()> {
175186 }
176187
177188 let duration = Duration :: from_micros ( ( t_main_end - t_main_start) as u64 ) ;
178- let total_tokens: usize = tokens_lines_list. iter ( ) . map ( |v| v . len ( ) ) . sum ( ) ;
189+ let total_tokens: usize = tokens_lines_list. iter ( ) . map ( Vec :: len) . sum ( ) ;
179190 eprintln ! (
180191 "Created embeddings for {} tokens in {:.2} s, speed {:.2} t/s\n " ,
181192 total_tokens,
@@ -188,12 +199,20 @@ fn main() -> Result<()> {
188199 Ok ( ( ) )
189200}
190201
191- fn batch_decode ( ctx : & mut LlamaContext , batch : & mut LlamaBatch , s_batch : i32 , output : & mut Vec < Vec < f32 > > , normalise : bool ) -> Result < ( ) > {
202+ fn batch_decode (
203+ ctx : & mut LlamaContext ,
204+ batch : & mut LlamaBatch ,
205+ s_batch : i32 ,
206+ output : & mut Vec < Vec < f32 > > ,
207+ normalise : bool ,
208+ ) -> Result < ( ) > {
192209 ctx. clear_kv_cache ( ) ;
193210 ctx. decode ( batch) . with_context ( || "llama_decode() failed" ) ?;
194211
195212 for i in 0 ..s_batch {
196- let embedding = ctx. embeddings_seq_ith ( i) . with_context ( || "Failed to get embeddings" ) ?;
213+ let embedding = ctx
214+ . embeddings_seq_ith ( i)
215+ . with_context ( || "Failed to get embeddings" ) ?;
197216 let output_embeddings = if normalise {
198217 normalize ( embedding)
199218 } else {
@@ -209,7 +228,10 @@ fn batch_decode(ctx: &mut LlamaContext, batch: &mut LlamaBatch, s_batch: i32, ou
209228}
210229
/// Normalizes `input` to unit length (L2 norm).
///
/// Each element is divided by the Euclidean magnitude of the slice, so the
/// returned vector has magnitude 1.0 — the usual preprocessing step before
/// cosine-similarity comparisons of embeddings.
///
/// If the magnitude is zero (empty or all-zero input) the input is returned
/// as-is instead of dividing by zero, which would otherwise fill the result
/// with NaNs and silently corrupt any downstream similarity computation.
fn normalize(input: &[f32]) -> Vec<f32> {
    // Sum of squares accumulated with fused multiply-add, then square root.
    let magnitude = input
        .iter()
        .fold(0.0, |acc, &val| val.mul_add(val, acc))
        .sqrt();

    // Guard: a zero vector has no direction; avoid 0/0 = NaN per element.
    if magnitude == 0.0 {
        return input.to_vec();
    }

    input.iter().map(|&val| val / magnitude).collect()
}
0 commit comments