@@ -12,7 +12,6 @@ use std::time::Duration;
 
 use anyhow::{bail, Context, Result};
 use clap::Parser;
-use hf_hub::api::sync::ApiBuilder;
 
 use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType};
 use llama_cpp_2::context::LlamaContext;
@@ -92,13 +91,11 @@ fn main() -> Result<()> {
         .with_n_threads_batch(std::thread::available_parallelism()?.get().try_into()?)
         .with_embeddings(true)
         .with_pooling_type(pooling_type);
-    println!("ctx_params: {:?}", ctx_params);
+    println!("ctx_params: {ctx_params:?}");
     let mut ctx = model
         .new_context(&backend, ctx_params)
         .with_context(|| "unable to create the llama_context")?;
 
-    let n_embd = model.n_embd();
-
     let prompt_lines = {
         let mut lines = Vec::new();
         for doc in documents {
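A note on the `println!` change in this hunk: since Rust 1.58, format strings can capture identifiers from the surrounding scope directly, and clippy's `uninlined_format_args` lint nudges code toward that form. A minimal standalone illustration (the vector is a stand-in value, not the real `ctx_params`):

```rust
fn main() {
    let ctx_params = vec![1, 2, 3]; // hypothetical placeholder value
    // Old style: the argument is passed positionally after the format string.
    println!("ctx_params: {:?}", ctx_params);
    // New style: the identifier is captured straight from scope.
    println!("ctx_params: {ctx_params:?}");
}
```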
@@ -108,13 +105,13 @@ fn main() -> Result<()> {
         lines
     };
 
-    println!("prompt_lines: {:?}", prompt_lines);
+    println!("prompt_lines: {prompt_lines:?}");
     // tokenize the prompt
     let tokens_lines_list = prompt_lines
         .iter()
         .map(|line| model.str_to_token(line, AddBos::Always))
         .collect::<Result<Vec<_>, _>>()
-        .with_context(|| format!("failed to tokenize {:?}", prompt_lines))?;
+        .with_context(|| format!("failed to tokenize {prompt_lines:?}"))?;
 
     let n_ctx = ctx.n_ctx() as usize;
     let n_ctx_train = model.n_ctx_train();
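The tokenization pipeline above leans on `Iterator::collect` turning an iterator of `Result`s into a `Result` of a `Vec`, short-circuiting on the first error before `with_context` attaches the prompt to any failure. A self-contained sketch of the same pattern, with `str::parse` standing in for `str_to_token`:

```rust
fn main() {
    let lines = ["1", "2", "x"];
    // Collecting Result items into Result<Vec<_>, _> stops at the first Err.
    let parsed: Result<Vec<i32>, _> = lines.iter().map(|s| s.parse::<i32>()).collect();
    assert!(parsed.is_err()); // "x" fails to parse, so the whole collect fails
}
```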
@@ -156,7 +153,6 @@ fn main() -> Result<()> {
     // } else {
     //     tokens_lines_list.len()
     // };
-    let mut embeddings_stored = 0;
     let mut max_seq_id_batch = 0;
     let mut output = Vec::with_capacity(tokens_lines_list.len());
 
@@ -169,16 +165,10 @@ fn main() -> Result<()> {
                 &mut ctx,
                 &mut batch,
                 max_seq_id_batch,
-                n_embd,
                 &mut output,
                 normalise,
-                pooling.clone(),
+                &pooling,
             )?;
-            embeddings_stored += if pooling == "none" {
-                batch.n_tokens()
-            } else {
-                max_seq_id_batch
-            };
             max_seq_id_batch = 0;
             batch.clear();
         }
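Passing `&pooling` here works because a `&String` deref-coerces to the `&str` parameter that the reworked `batch_decode` signature expects, so the per-batch `String` clone disappears. A minimal sketch of the coercion, using a hypothetical helper:

```rust
// Hypothetical stand-in for batch_decode's pooling parameter.
fn takes_str(pooling: &str) -> bool {
    pooling == "rank"
}

fn main() {
    let pooling = String::from("rank");
    // &String coerces to &str at the call site; no clone needed.
    assert!(takes_str(&pooling));
    // `pooling` is still usable afterwards because it was only borrowed.
    println!("{pooling}");
}
```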
@@ -191,34 +181,23 @@ fn main() -> Result<()> {
         &mut ctx,
         &mut batch,
         max_seq_id_batch,
-        n_embd,
         &mut output,
         normalise,
-        pooling.clone(),
+        &pooling,
     )?;
 
     let t_main_end = ggml_time_us();
 
     for (j, embeddings) in output.iter().enumerate() {
-        if pooling == "none" {
-            eprintln!("embedding {j}: ");
-            for i in 0..n_embd as usize {
-                if !normalise {
-                    eprint!("{:6.5} ", embeddings[i]);
-                } else {
-                    eprint!("{:9.6} ", embeddings[i]);
-                }
-            }
-            eprintln!();
-        } else if pooling == "rank" {
+        if pooling == "rank" {
             eprintln!("rerank score {j}: {:8.3}", embeddings[0]);
         } else {
             eprintln!("embedding {j}: ");
-            for i in 0..n_embd as usize {
-                if !normalise {
-                    eprint!("{:6.5} ", embeddings[i]);
+            for embedding in embeddings {
+                if normalise {
+                    eprint!("{embedding:9.6} ");
                 } else {
-                    eprint!("{:9.6} ", embeddings[i]);
+                    eprint!("{embedding:6.5} ");
                 }
             }
             eprintln!();
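Iterating `for embedding in embeddings` instead of indexing `0..n_embd` is what lets `n_embd` drop out entirely: each `Vec<f32>` in `output` already knows its own length, and slice iteration avoids a bounds check per index. A small sketch of the same printing loop with hypothetical values:

```rust
fn main() {
    let embeddings = vec![0.123_456_f32, -0.5, 0.75]; // hypothetical values
    for embedding in &embeddings {
        // width 9, 6 decimal places, matching the normalised branch above
        eprint!("{embedding:9.6} ");
    }
    eprintln!();
}
```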
@@ -243,10 +222,9 @@ fn batch_decode(
     ctx: &mut LlamaContext,
     batch: &mut LlamaBatch,
     s_batch: i32,
-    n_embd: i32,
     output: &mut Vec<Vec<f32>>,
     normalise: bool,
-    pooling: String,
+    pooling: &str,
 ) -> Result<()> {
     eprintln!(
         "{}: n_tokens = {}, n_seq = {}",
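Switching the parameter from `pooling: String` to `pooling: &str` follows the usual Rust guideline of borrowing in signatures when the callee only reads the value. A compact sketch of what each signature demands from callers, with hypothetical function names:

```rust
fn owned(pooling: String) {
    // Forces the caller to move in or clone its String.
    let _ = pooling;
}

fn borrowed(pooling: &str) {
    // Accepts String references and string literals alike.
    let _ = pooling;
}

fn main() {
    let p = String::from("mean");
    owned(p.clone()); // clone required to keep using `p`
    borrowed(&p);     // just a borrow
    borrowed("mean"); // literals work too
}
```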
@@ -266,9 +244,9 @@ fn batch_decode(
             .with_context(|| "Failed to get sequence embeddings")?;
         let normalized = if normalise {
             if pooling == "rank" {
-                normalize_embeddings(&embeddings, -1)
+                normalize_embeddings(embeddings, -1)
             } else {
-                normalize_embeddings(&embeddings, 2)
+                normalize_embeddings(embeddings, 2)
             }
         } else {
             embeddings.to_vec()
@@ -291,27 +269,30 @@ fn normalize_embeddings(input: &[f32], embd_norm: i32) -> Vec<f32> {
         0 => {
             // max absolute
             let max_abs = input.iter().map(|x| x.abs()).fold(0.0f32, f32::max) / 32760.0;
-            max_abs as f64
+            f64::from(max_abs)
         }
         2 => {
             // euclidean norm
             input
                 .iter()
-                .map(|x| (*x as f64).powi(2))
+                .map(|x| f64::from(*x).powi(2))
                 .sum::<f64>()
                 .sqrt()
         }
         p => {
             // p-norm
-            let sum = input.iter().map(|x| (x.abs() as f64).powi(p)).sum::<f64>();
-            sum.powf(1.0 / p as f64)
+            let sum = input
+                .iter()
+                .map(|x| f64::from(x.abs()).powi(p))
+                .sum::<f64>();
+            sum.powf(1.0 / f64::from(p))
         }
     };
 
     let norm = if sum > 0.0 { 1.0 / sum } else { 0.0 };
 
     for i in 0..n {
-        output[i] = (input[i] as f64 * norm) as f32;
+        output[i] = (f64::from(input[i]) * norm) as f32;
     }
 
     output
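Two things in this hunk are worth unpacking. First, `f64::from(x)` is the lossless-widening equivalent of `x as f64` for `f32` and `i32` values; clippy's `cast_lossless` lint prefers it because `From` is only implemented for conversions that cannot lose data. Second, the match implements a family of norms: `embd_norm == 2` is the Euclidean norm `sqrt(sum(x_i^2))`, and a general `p` gives `(sum(|x_i|^p))^(1/p)`. A sketch checking the Euclidean case against a hand-computed value, using a simplified copy of just that branch:

```rust
// Simplified copy of the Euclidean branch, for illustration only.
fn l2_normalize(input: &[f32]) -> Vec<f32> {
    let sum = input
        .iter()
        .map(|x| f64::from(*x).powi(2))
        .sum::<f64>()
        .sqrt();
    let norm = if sum > 0.0 { 1.0 / sum } else { 0.0 };
    input.iter().map(|x| (f64::from(*x) * norm) as f32).collect()
}

fn main() {
    // [3, 4] has Euclidean norm 5, so the result should be [0.6, 0.8].
    let normalized = l2_normalize(&[3.0, 4.0]);
    assert!((normalized[0] - 0.6).abs() < 1e-6);
    assert!((normalized[1] - 0.8).abs() < 1e-6);
}
```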