Skip to content

Commit 548f036

Browse files
Authored: Merge pull request #137 from utilityai/sample-rep — added `sample_repetition_penalty`
2 parents 14b8187 + 4c2cd79 commit 548f036

File tree

12 files changed

+228
-69
lines changed

12 files changed

+228
-69
lines changed

embeddings/src/main.rs

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,27 @@
11
//! This is a translation of embedding.cpp in llama.cpp using llama-cpp-2.
22
#![allow(
3-
clippy::cast_possible_wrap,
4-
clippy::cast_possible_truncation,
5-
clippy::cast_precision_loss,
6-
clippy::cast_sign_loss
3+
clippy::cast_possible_wrap,
4+
clippy::cast_possible_truncation,
5+
clippy::cast_precision_loss,
6+
clippy::cast_sign_loss
77
)]
88

99
use std::io::Write;
1010
use std::path::PathBuf;
11-
use std::str::FromStr;
1211
use std::time::Duration;
1312

1413
use anyhow::{bail, Context, Result};
1514
use clap::Parser;
1615
use hf_hub::api::sync::ApiBuilder;
1716

18-
use llama_cpp_2::context::LlamaContext;
1917
use llama_cpp_2::context::params::LlamaContextParams;
18+
use llama_cpp_2::context::LlamaContext;
2019
use llama_cpp_2::ggml_time_us;
2120
use llama_cpp_2::llama_backend::LlamaBackend;
2221
use llama_cpp_2::llama_batch::LlamaBatch;
22+
use llama_cpp_2::model::params::LlamaModelParams;
2323
use llama_cpp_2::model::AddBos;
2424
use llama_cpp_2::model::LlamaModel;
25-
use llama_cpp_2::model::params::LlamaModelParams;
2625

2726
#[derive(clap::Parser, Debug, Clone)]
2827
struct Args {
@@ -41,7 +40,6 @@ struct Args {
4140
disable_gpu: bool,
4241
}
4342

44-
4543
#[derive(clap::Subcommand, Debug, Clone)]
4644
enum Model {
4745
/// Use an already downloaded model
@@ -119,7 +117,8 @@ fn main() -> Result<()> {
119117
let prompt_lines = prompt.lines();
120118

121119
// tokenize the prompt
122-
let tokens_lines_list = prompt_lines.map(|line| model.str_to_token(&line, AddBos::Always))
120+
let tokens_lines_list = prompt_lines
121+
.map(|line| model.str_to_token(line, AddBos::Always))
123122
.collect::<Result<Vec<_>, _>>()
124123
.with_context(|| format!("failed to tokenize {prompt}"))?;
125124

@@ -140,7 +139,7 @@ fn main() -> Result<()> {
140139
for token in token_line {
141140
eprintln!(" {} --> {}", token, model.token_to_str(*token)?);
142141
}
143-
eprintln!()
142+
eprintln!();
144143
}
145144

146145
std::io::stderr().flush()?;
@@ -157,15 +156,27 @@ fn main() -> Result<()> {
157156
for tokens in &tokens_lines_list {
158157
// Flush the batch if the next prompt would exceed our batch size
159158
if (batch.n_tokens() as usize + tokens.len()) > n_ctx {
160-
batch_decode(&mut ctx, &mut batch, max_seq_id_batch, &mut output, normalise)?;
159+
batch_decode(
160+
&mut ctx,
161+
&mut batch,
162+
max_seq_id_batch,
163+
&mut output,
164+
normalise,
165+
)?;
161166
max_seq_id_batch = 0;
162167
}
163168

164-
batch.add_sequence(&tokens, max_seq_id_batch, false)?;
169+
batch.add_sequence(tokens, max_seq_id_batch, false)?;
165170
max_seq_id_batch += 1;
166171
}
167172
// Handle final batch
168-
batch_decode(&mut ctx, &mut batch, max_seq_id_batch, &mut output, normalise)?;
173+
batch_decode(
174+
&mut ctx,
175+
&mut batch,
176+
max_seq_id_batch,
177+
&mut output,
178+
normalise,
179+
)?;
169180

170181
let t_main_end = ggml_time_us();
171182

@@ -175,7 +186,7 @@ fn main() -> Result<()> {
175186
}
176187

177188
let duration = Duration::from_micros((t_main_end - t_main_start) as u64);
178-
let total_tokens: usize = tokens_lines_list.iter().map(|v| v.len()).sum();
189+
let total_tokens: usize = tokens_lines_list.iter().map(Vec::len).sum();
179190
eprintln!(
180191
"Created embeddings for {} tokens in {:.2} s, speed {:.2} t/s\n",
181192
total_tokens,
@@ -188,12 +199,20 @@ fn main() -> Result<()> {
188199
Ok(())
189200
}
190201

191-
fn batch_decode(ctx: &mut LlamaContext, batch: &mut LlamaBatch, s_batch: i32, output: &mut Vec<Vec<f32>>, normalise: bool) -> Result<()> {
202+
fn batch_decode(
203+
ctx: &mut LlamaContext,
204+
batch: &mut LlamaBatch,
205+
s_batch: i32,
206+
output: &mut Vec<Vec<f32>>,
207+
normalise: bool,
208+
) -> Result<()> {
192209
ctx.clear_kv_cache();
193210
ctx.decode(batch).with_context(|| "llama_decode() failed")?;
194211

195212
for i in 0..s_batch {
196-
let embedding = ctx.embeddings_seq_ith(i).with_context(|| "Failed to get embeddings")?;
213+
let embedding = ctx
214+
.embeddings_seq_ith(i)
215+
.with_context(|| "Failed to get embeddings")?;
197216
let output_embeddings = if normalise {
198217
normalize(embedding)
199218
} else {
@@ -209,7 +228,10 @@ fn batch_decode(ctx: &mut LlamaContext, batch: &mut LlamaBatch, s_batch: i32, ou
209228
}
210229

211230
fn normalize(input: &[f32]) -> Vec<f32> {
212-
let magnitude = input.iter().fold(0.0, |acc, &val| val.mul_add(val, acc)).sqrt();
231+
let magnitude = input
232+
.iter()
233+
.fold(0.0, |acc, &val| val.mul_add(val, acc))
234+
.sqrt();
213235

214236
input.iter().map(|&val| val / magnitude).collect()
215237
}

llama-cpp-2/src/context.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,19 +95,26 @@ impl<'model> LlamaContext<'model> {
9595
/// - When the current context was constructed without enabling embeddings.
9696
/// - If the current model had a pooling type of [`llama_cpp_sys_2::LLAMA_POOLING_TYPE_NONE`]
9797
/// - If the given sequence index exceeds the max sequence id.
98+
///
99+
/// # Panics
100+
///
101+
/// * `n_embd` does not fit into a usize
98102
pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
99103
if !self.embeddings_enabled {
100104
return Err(EmbeddingsError::NotEnabled);
101105
}
102106

107+
let n_embd =
108+
usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
109+
103110
unsafe {
104111
let embedding = llama_cpp_sys_2::llama_get_embeddings_seq(self.context.as_ptr(), i);
105112

106113
// Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
107114
if embedding.is_null() {
108115
Err(EmbeddingsError::NonePoolType)
109116
} else {
110-
Ok(std::slice::from_raw_parts(embedding, self.model.n_embd() as usize))
117+
Ok(slice::from_raw_parts(embedding, n_embd))
111118
}
112119
}
113120
}
@@ -124,18 +131,25 @@ impl<'model> LlamaContext<'model> {
124131
/// - When the current context was constructed without enabling embeddings.
125132
/// - When the given token didn't have logits enabled when it was passed.
126133
/// - If the given token index exceeds the max token id.
134+
///
135+
/// # Panics
136+
///
137+
/// * `n_embd` does not fit into a usize
127138
pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
128139
if !self.embeddings_enabled {
129140
return Err(EmbeddingsError::NotEnabled);
130141
}
131142

143+
let n_embd =
144+
usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
145+
132146
unsafe {
133147
let embedding = llama_cpp_sys_2::llama_get_embeddings_ith(self.context.as_ptr(), i);
134148
// Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
135149
if embedding.is_null() {
136150
Err(EmbeddingsError::LogitsNotEnabled)
137151
} else {
138-
Ok(std::slice::from_raw_parts(embedding, self.model.n_embd() as usize))
152+
Ok(slice::from_raw_parts(embedding, n_embd))
139153
}
140154
}
141155
}

llama-cpp-2/src/context/kv_cache.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ impl LlamaContext<'_> {
6767
unsafe { llama_cpp_sys_2::llama_kv_cache_seq_keep(self.context.as_ptr(), seq_id) }
6868
}
6969

70+
#[allow(clippy::doc_markdown)]
7071
/// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
7172
/// If the KV cache is RoPEd, the KV data is updated accordingly:
7273
/// - lazily on next [`LlamaContext::decode`]
@@ -212,9 +213,9 @@ impl<'a> KVCacheView<'a> {
212213
}
213214

214215
/// Information for individual cells.
215-
///
216+
///
216217
/// # Panics
217-
///
218+
///
218219
/// - if `n_cells` does not fit into usize.
219220
pub fn cells(&self) -> impl Iterator<Item = KVCacheViewCell> {
220221
unsafe {
@@ -228,9 +229,9 @@ impl<'a> KVCacheView<'a> {
228229
}
229230

230231
/// The sequences for each cell. There will be `n_max_seq` items per cell.
231-
///
232+
///
232233
/// # Panics
233-
///
234+
///
234235
/// - if `n_cells * n_max_seq` does not fit into usize.
235236
/// - if `n_max_seq` does not fit into usize.
236237
pub fn cells_sequences(&self) -> impl Iterator<Item = &[llama_cpp_sys_2::llama_seq_id]> {

llama-cpp-2/src/context/sample.rs

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,24 @@ use crate::context::LlamaContext;
44
use crate::grammar::LlamaGrammar;
55
use crate::token::data_array::LlamaTokenDataArray;
66
use crate::token::LlamaToken;
7-
use llama_cpp_sys_2::llama_context;
87

98
/// struct to hold params for sampling
109
#[derive(Debug)]
11-
#[deprecated(since = "0.1.32", note = "this does not scale well with many params and does not allow for changing of orders.")]
10+
#[deprecated(
11+
since = "0.1.32",
12+
note = "this does not scale well with many params and does not allow for changing of orders."
13+
)]
1214
pub struct Sampler<'grammar> {
1315
token_data_array: LlamaTokenDataArray,
1416
grammar: Option<&'grammar mut LlamaGrammar>,
1517
temperature: Option<f32>,
1618
}
1719

1820
impl<'grammar> Sampler<'grammar> {
19-
#[deprecated(since = "0.1.32", note = "this does not scale well with many params and does not allow for changing of orders.")]
21+
#[deprecated(
22+
since = "0.1.32",
23+
note = "this does not scale well with many params and does not allow for changing of orders."
24+
)]
2025
fn sample(self, llama_context: &mut LlamaContext) -> LlamaToken {
2126
match self {
2227
Sampler {
@@ -60,7 +65,10 @@ impl<'grammar> Sampler<'grammar> {
6065

6166
/// Create a new sampler.
6267
#[must_use]
63-
#[deprecated(since = "0.1.32", note = "this does not scale well with many params and does not allow for changing of orders.")]
68+
#[deprecated(
69+
since = "0.1.32",
70+
note = "this does not scale well with many params and does not allow for changing of orders."
71+
)]
6472
pub fn new(llama_token_data_array: LlamaTokenDataArray) -> Self {
6573
Self {
6674
token_data_array: llama_token_data_array,
@@ -71,7 +79,10 @@ impl<'grammar> Sampler<'grammar> {
7179

7280
/// Set the grammar for sampling.
7381
#[must_use]
74-
#[deprecated(since = "0.1.32", note = "this does not scale well with many params and does not allow for changing of orders.")]
82+
#[deprecated(
83+
since = "0.1.32",
84+
note = "this does not scale well with many params and does not allow for changing of orders."
85+
)]
7586
pub fn with_grammar(mut self, grammar: &'grammar mut LlamaGrammar) -> Self {
7687
self.grammar = Some(grammar);
7788
self
@@ -91,7 +102,10 @@ impl<'grammar> Sampler<'grammar> {
91102
/// .with_temperature(0.5);
92103
/// ```
93104
#[must_use]
94-
#[deprecated(since = "0.1.32", note = "this does not scale well with many params and does not allow for changing of orders.")]
105+
#[deprecated(
106+
since = "0.1.32",
107+
note = "this does not scale well with many params and does not allow for changing of orders."
108+
)]
95109
pub fn with_temperature(mut self, temperature: f32) -> Self {
96110
if temperature == 0.0 {
97111
return self;
@@ -107,7 +121,10 @@ impl LlamaContext<'_> {
107121
/// # Panics
108122
///
109123
/// - sampler contains no tokens
110-
#[deprecated(since = "0.1.32", note = "this does not scale well with many params and does not allow for changing of orders.")]
124+
#[deprecated(
125+
since = "0.1.32",
126+
note = "this does not scale well with many params and does not allow for changing of orders."
127+
)]
111128
pub fn sample(&mut self, sampler: Sampler) -> LlamaToken {
112129
sampler.sample(self)
113130
}
@@ -157,7 +174,7 @@ impl LlamaContext<'_> {
157174
if temperature == 0.0 {
158175
return;
159176
}
160-
let ctx: *mut llama_context = self.context.as_ptr();
177+
let ctx: *mut llama_cpp_sys_2::llama_context = self.context.as_ptr();
161178
unsafe {
162179
token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
163180
llama_cpp_sys_2::llama_sample_temp(ctx, c_llama_token_data_array, temperature);
@@ -254,4 +271,24 @@ impl LlamaContext<'_> {
254271
});
255272
}
256273
}
274+
275+
/// See [`LlamaTokenDataArray::sample_repetition_penalty`]
276+
pub fn sample_repetition_penalty(
277+
&mut self,
278+
token_data: &mut LlamaTokenDataArray,
279+
last_tokens: &[LlamaToken],
280+
penalty_last_n: usize,
281+
penalty_repeat: f32,
282+
penalty_freq: f32,
283+
penalty_present: f32,
284+
) {
285+
token_data.sample_repetition_penalty(
286+
Some(self),
287+
last_tokens,
288+
penalty_last_n,
289+
penalty_repeat,
290+
penalty_freq,
291+
penalty_present,
292+
);
293+
}
257294
}

llama-cpp-2/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ pub enum LLamaCppError {
5252
/// There was an error adding a token to a batch.
5353
#[error["{0}"]]
5454
BatchAddError(#[from] BatchAddError),
55+
/// see [`EmbeddingsError`]
5556
#[error(transparent)]
5657
EmbeddingError(#[from] EmbeddingsError),
5758
}
@@ -81,10 +82,13 @@ pub enum DecodeError {
8182
/// When embedding related functions fail
8283
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
8384
pub enum EmbeddingsError {
85+
/// Embeddings weren't enabled in the context options
8486
#[error("Embeddings weren't enabled in the context options")]
8587
NotEnabled,
88+
/// Logits weren't enabled for the given token
8689
#[error("Logits were not enabled for the given token")]
8790
LogitsNotEnabled,
91+
/// The given sequence index exceeds the max sequence id
8892
#[error("Can't use sequence embeddings with a model supporting only LLAMA_POOLING_TYPE_NONE")]
8993
NonePoolType,
9094
}

llama-cpp-2/src/llama_backend.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
//! Representation of an initialized llama backend
22
33
use crate::LLamaCppError;
4+
use llama_cpp_sys_2::ggml_log_level;
45
use std::sync::atomic::AtomicBool;
56
use std::sync::atomic::Ordering::SeqCst;
6-
use llama_cpp_sys_2::ggml_log_level;
77

88
/// Representation of an initialized llama backend
99
/// This is required as a parameter for most llama functions as the backend must be initialized
@@ -76,10 +76,11 @@ impl LlamaBackend {
7676
_level: ggml_log_level,
7777
_text: *const ::std::os::raw::c_char,
7878
_user_data: *mut ::std::os::raw::c_void,
79-
) {}
79+
) {
80+
}
8081

8182
unsafe {
82-
llama_cpp_sys_2::llama_log_set(Some(void_log), std::ptr::null_mut())
83+
llama_cpp_sys_2::llama_log_set(Some(void_log), std::ptr::null_mut());
8384
}
8485
}
8586
}

0 commit comments

Comments (0)