Commit b2284b7: rustfmt

1 parent: 71afef1

7 files changed (+82, -40 lines)

Cargo.lock

Lines changed: 2 additions & 2 deletions
(Generated file; diff not rendered by default.)

llama-cpp-2/benches/generate.rs

Lines changed: 7 additions & 5 deletions
@@ -1,12 +1,12 @@
 use anyhow::Context;
-use criterion::{Criterion, criterion_group, criterion_main};
-use pprof::criterion::{Output, PProfProfiler};
+use criterion::{criterion_group, criterion_main, Criterion};
 use llama_cpp_2::context::params::LlamaContextParams;
 use llama_cpp_2::llama_backend::LlamaBackend;
 use llama_cpp_2::llama_batch::LlamaBatch;
-use llama_cpp_2::model::{AddBos, LlamaModel};
 use llama_cpp_2::model::params::LlamaModelParams;
+use llama_cpp_2::model::{AddBos, LlamaModel};
 use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+use pprof::criterion::{Output, PProfProfiler};
 
 fn generate(c: &mut Criterion) {
     let api = hf_hub::api::sync::ApiBuilder::new()
@@ -26,7 +26,9 @@ fn generate(c: &mut Criterion) {
 
     c.bench_function("generate 50 tokens", |b| {
         b.iter(|| {
-            let tokens_list = model.str_to_token("Hello, my name is", AddBos::Always).unwrap();
+            let tokens_list = model
+                .str_to_token("Hello, my name is", AddBos::Always)
+                .unwrap();
             let mut n_ctx = tokens_list.len() as i32;
             let mut batch = LlamaBatch::new(512, 1);
             let last_index: i32 = (tokens_list.len() - 1) as i32;
@@ -58,4 +60,4 @@ criterion_group!(
     config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
     targets = generate
 );
-criterion_main!(benches);
+criterion_main!(benches);
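
For orientation, the benchmark reformatted above follows criterion's standard profiler wiring. Below is a minimal, self-contained sketch of the same shape; the `workload` function is a hypothetical stand-in, and the `PProfProfiler` configuration simply mirrors the one in the hunk above (assuming the pprof crate's criterion/flamegraph support, which this bench already uses):

```rust
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use pprof::criterion::{Output, PProfProfiler};

// Hypothetical stand-in workload; the real benchmark tokenizes a prompt and generates tokens.
fn workload(n: u64) -> u64 {
    (0..n).sum()
}

fn bench_workload(c: &mut Criterion) {
    // black_box keeps the compiler from optimizing the call away.
    c.bench_function("sum to n", |b| b.iter(|| workload(black_box(1_000))));
}

criterion_group!(
    name = benches;
    config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
    targets = bench_workload
);
criterion_main!(benches);
```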

llama-cpp-2/src/context/params.rs

Lines changed: 28 additions & 14 deletions
@@ -84,7 +84,8 @@ impl LlamaContextParams {
     /// let params = params.with_seed(1234);
     /// assert_eq!(params.seed(), 1234);
     /// ```
-    #[must_use] pub fn with_seed(mut self, seed: u32) -> Self {
+    #[must_use]
+    pub fn with_seed(mut self, seed: u32) -> Self {
         self.context_params.seed = seed;
         self
     }
@@ -99,7 +100,8 @@ impl LlamaContextParams {
     /// .with_seed(1234);
     /// assert_eq!(params.seed(), 1234);
     /// ```
-    #[must_use] pub fn seed(&self) -> u32 {
+    #[must_use]
+    pub fn seed(&self) -> u32 {
         self.context_params.seed
     }
 
@@ -114,7 +116,8 @@ impl LlamaContextParams {
     /// let params = params.with_n_ctx(NonZeroU32::new(2048));
     /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
     /// ```
-    #[must_use] pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
+    #[must_use]
+    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
         self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
         self
     }
@@ -128,7 +131,8 @@ impl LlamaContextParams {
     /// ```rust
     /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
     /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
-    #[must_use] pub fn n_ctx(&self) -> Option<NonZeroU32> {
+    #[must_use]
+    pub fn n_ctx(&self) -> Option<NonZeroU32> {
         NonZeroU32::new(self.context_params.n_ctx)
     }
 
@@ -143,7 +147,8 @@ impl LlamaContextParams {
     /// .with_n_batch(2048);
     /// assert_eq!(params.n_batch(), 2048);
     /// ```
-    #[must_use] pub fn with_n_batch(mut self, n_batch: u32) -> Self {
+    #[must_use]
+    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
         self.context_params.n_batch = n_batch;
         self
     }
@@ -157,7 +162,8 @@ impl LlamaContextParams {
     /// let params = LlamaContextParams::default();
     /// assert_eq!(params.n_batch(), 512);
     /// ```
-    #[must_use] pub fn n_batch(&self) -> u32 {
+    #[must_use]
+    pub fn n_batch(&self) -> u32 {
         self.context_params.n_batch
     }
 
@@ -171,7 +177,8 @@ impl LlamaContextParams {
     /// .with_rope_scaling_type(RopeScalingType::Linear);
     /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
     /// ```
-    #[must_use] pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
+    #[must_use]
+    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
         self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
         self
     }
@@ -184,7 +191,8 @@ impl LlamaContextParams {
     /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
     /// assert_eq!(params.rope_scaling_type(), llama_cpp_2::context::params::RopeScalingType::Unspecified);
     /// ```
-    #[must_use] pub fn rope_scaling_type(&self) -> RopeScalingType {
+    #[must_use]
+    pub fn rope_scaling_type(&self) -> RopeScalingType {
         RopeScalingType::from(self.context_params.rope_scaling_type)
     }
 
@@ -198,7 +206,8 @@ impl LlamaContextParams {
     /// .with_rope_freq_base(0.5);
     /// assert_eq!(params.rope_freq_base(), 0.5);
     /// ```
-    #[must_use] pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
+    #[must_use]
+    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
         self.context_params.rope_freq_base = rope_freq_base;
         self
     }
@@ -211,7 +220,8 @@ impl LlamaContextParams {
     /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
     /// assert_eq!(params.rope_freq_base(), 0.0);
     /// ```
-    #[must_use] pub fn rope_freq_base(&self) -> f32 {
+    #[must_use]
+    pub fn rope_freq_base(&self) -> f32 {
         self.context_params.rope_freq_base
     }
 
@@ -225,7 +235,8 @@ impl LlamaContextParams {
     /// .with_rope_freq_scale(0.5);
     /// assert_eq!(params.rope_freq_scale(), 0.5);
     /// ```
-    #[must_use] pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
+    #[must_use]
+    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
         self.context_params.rope_freq_scale = rope_freq_scale;
         self
     }
@@ -238,7 +249,8 @@ impl LlamaContextParams {
     /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
     /// assert_eq!(params.rope_freq_scale(), 0.0);
     /// ```
-    #[must_use] pub fn rope_freq_scale(&self) -> f32 {
+    #[must_use]
+    pub fn rope_freq_scale(&self) -> f32 {
         self.context_params.rope_freq_scale
     }
 
@@ -250,7 +262,8 @@ impl LlamaContextParams {
     /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
     /// assert_eq!(params.n_threads(), 4);
     /// ```
-    #[must_use] pub fn n_threads(&self) -> u32 {
+    #[must_use]
+    pub fn n_threads(&self) -> u32 {
         self.context_params.n_threads
     }
 
@@ -264,7 +277,8 @@ impl LlamaContextParams {
     /// .with_n_threads(8);
     /// assert_eq!(params.n_threads(), 8);
     /// ```
-    #[must_use] pub fn with_n_threads(mut self, n_threads: u32) -> Self {
+    #[must_use]
+    pub fn with_n_threads(mut self, n_threads: u32) -> Self {
         self.context_params.n_threads = n_threads;
         self
     }
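
Taken together, these are builder-style setters that return `Self`, so they are normally chained. A minimal sketch using only the methods and values already shown in the doctests above (assuming a `llama_cpp_2` dependency):

```rust
use std::num::NonZeroU32;

use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};

fn build_params() -> LlamaContextParams {
    // Each setter is #[must_use] and returns Self, so the calls chain.
    LlamaContextParams::default()
        .with_seed(1234)
        .with_n_ctx(NonZeroU32::new(2048))
        .with_n_batch(2048)
        .with_n_threads(8)
        .with_rope_scaling_type(RopeScalingType::Linear)
}
```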

llama-cpp-2/src/context/session.rs

Lines changed: 19 additions & 8 deletions
@@ -1,9 +1,9 @@
 //! utilities for working with session files
 
-use std::ffi::{CString, NulError};
-use std::path::{Path, PathBuf};
 use crate::context::LlamaContext;
 use crate::token::LlamaToken;
+use std::ffi::{CString, NulError};
+use std::path::{Path, PathBuf};
 
 #[derive(Debug, Eq, PartialEq, thiserror::Error)]
 pub enum SaveSessionError {
@@ -36,11 +36,15 @@ impl LlamaContext<'_> {
     ///
     /// * `path_session` - The file to save to.
     /// * `tokens` - The tokens to associate the session with. This should be a prefix of a sequence of tokens that the context has processed, so that the relevant KV caches are already filled.
-    pub fn save_session_file(&self, path_session: impl AsRef<Path>, tokens: &[LlamaToken]) -> Result<(), SaveSessionError> {
+    pub fn save_session_file(
+        &self,
+        path_session: impl AsRef<Path>,
+        tokens: &[LlamaToken],
+    ) -> Result<(), SaveSessionError> {
         let path = path_session.as_ref();
         let path = path
             .to_str()
-            .ok_or(SaveSessionError::PathToStrError(path.to_path_buf()))?;
+            .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;
 
         let cstr = CString::new(path)?;
 
@@ -49,7 +53,8 @@ impl LlamaContext<'_> {
                 self.context.as_ptr(),
                 cstr.as_ptr(),
                 tokens.as_ptr() as *const i32,
-                tokens.len())
+                tokens.len(),
+            )
         } {
             Ok(())
         } else {
@@ -64,7 +69,11 @@ impl LlamaContext<'_> {
     ///
     /// * `path_session` - The file to load from. It must be a session file from a compatible context, otherwise the function will error.
     /// * `max_tokens` - The maximum token length of the loaded session. If the session was saved with a longer length, the function will error.
-    pub fn load_session_file(&mut self, path_session: impl AsRef<Path>, max_tokens: usize) -> Result<Vec<LlamaToken>, LoadSessionError> {
+    pub fn load_session_file(
+        &mut self,
+        path_session: impl AsRef<Path>,
+        max_tokens: usize,
+    ) -> Result<Vec<LlamaToken>, LoadSessionError> {
         let path = path_session.as_ref();
         let path = path
             .to_str()
@@ -80,12 +89,14 @@ impl LlamaContext<'_> {
                 cstr.as_ptr(),
                 tokens.as_mut_ptr() as *mut i32,
                 max_tokens,
-                &mut n_out) {
+                &mut n_out,
+            ) {
+                assert!(n_out <= max_tokens, "n_out is greater than max_tokens");
                 tokens.set_len(n_out);
                 Ok(tokens)
             } else {
                 Err(LoadSessionError::FailedToLoad)
             }
         }
     }
-}
+}

llama-cpp-2/src/lib.rs

Lines changed: 1 addition & 1 deletion
@@ -196,7 +196,7 @@ pub fn ggml_time_us() -> i64 {
 }
 
 /// checks if mlock is supported
-/// 
+///
 /// ```
 /// # use llama_cpp_2::llama_supports_mlock;
 ///

llama-cpp-sys-2/build.rs

Lines changed: 6 additions & 4 deletions
@@ -7,7 +7,11 @@ fn main() {
 
     let cublas_enabled = env::var("CARGO_FEATURE_CUBLAS").is_ok();
 
-    let mut ggml_cuda = if cublas_enabled { Some(cc::Build::new()) } else { None };
+    let mut ggml_cuda = if cublas_enabled {
+        Some(cc::Build::new())
+    } else {
+        None
+    };
 
     if !Path::new("llama.cpp/ggml.c").exists() {
         panic!("llama.cpp seems to not be populated, try running `git submodule update --init --recursive` to init.")
@@ -56,9 +60,7 @@ fn main() {
         if ggml_cuda.get_compiler().is_like_msvc() {
             ggml_cuda.std("c++14");
         } else {
-            ggml_cuda
-                .flag("-std=c++11")
-                .std("c++11");
+            ggml_cuda.flag("-std=c++11").std("c++11");
         }
 
         ggml.define("GGML_USE_CUBLAS", None);

simple/src/main.rs

Lines changed: 19 additions & 6 deletions
@@ -1,8 +1,14 @@
 //! This is a translation of simple.cpp in llama.cpp using llama-cpp-2.
-#![allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation, clippy::cast_precision_loss, clippy::cast_sign_loss)]
+#![allow(
+    clippy::cast_possible_wrap,
+    clippy::cast_possible_truncation,
+    clippy::cast_precision_loss,
+    clippy::cast_sign_loss
+)]
 
 use anyhow::{bail, Context, Result};
 use clap::Parser;
+use hf_hub::api::sync::ApiBuilder;
 use llama_cpp_2::context::params::LlamaContextParams;
 use llama_cpp_2::ggml_time_us;
 use llama_cpp_2::llama_backend::LlamaBackend;
@@ -15,7 +21,6 @@ use std::io::Write;
 use std::num::NonZeroU32;
 use std::path::PathBuf;
 use std::time::Duration;
-use hf_hub::api::sync::ApiBuilder;
 
 #[derive(clap::Parser, Debug, Clone)]
 struct Args {
@@ -62,13 +67,19 @@ impl Model {
                 .with_context(|| "unable to create huggingface api")?
                 .model(repo)
                 .get(&model)
-                .with_context(|| "unable to download model")
+                .with_context(|| "unable to download model"),
         }
     }
 }
 
 fn main() -> Result<()> {
-    let Args { n_len, model, prompt, #[cfg(feature = "cublas")] disable_gpu } = Args::parse();
+    let Args {
+        n_len,
+        model,
+        prompt,
+        #[cfg(feature = "cublas")]
+        disable_gpu,
+    } = Args::parse();
 
     // init LLM
     let backend = LlamaBackend::init()?;
@@ -84,8 +95,10 @@ fn main() -> Result<()> {
         #[cfg(not(feature = "cublas"))]
         LlamaModelParams::default()
     };
-
-    let model_path = model.to_path().with_context(|| "failed to get model from args")?;
+
+    let model_path = model
+        .to_path()
+        .with_context(|| "failed to get model from args")?;
 
     let model = LlamaModel::load_from_file(&backend, model_path, &model_params)
         .with_context(|| "unable to load model")?;
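
One detail worth noting in the expanded `Args` destructuring: a `#[cfg]`-gated struct field must also be cfg-gated in the pattern that destructures it. A standalone sketch of that pattern (the `Opts` struct is hypothetical; the feature name simply mirrors this crate's `cublas` feature):

```rust
struct Opts {
    n_len: i32,
    #[cfg(feature = "cublas")]
    disable_gpu: bool,
}

fn report(opts: Opts) {
    // The cfg-gated field must be cfg-gated in the pattern too, otherwise the
    // destructuring fails to compile when the feature is disabled.
    let Opts {
        n_len,
        #[cfg(feature = "cublas")]
        disable_gpu,
    } = opts;

    println!("n_len = {n_len}");
    #[cfg(feature = "cublas")]
    println!("disable_gpu = {disable_gpu}");
}
```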
