
Commit e9e80e2

Merge pull request #65 from utilityai/8-metal-on-mac
attempt to add metal on mac
2 parents 2e05e66 + b6e0bf7 commit e9e80e2

File tree

5 files changed: +150 −19 lines changed


Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default.

llama-cpp-2/Cargo.toml

Lines changed: 4 additions & 0 deletions
@@ -26,6 +26,10 @@ anyhow = "1.0.80"
 name = "grammar_bias"
 harness = false

+[[bench]]
+name = "generate"
+harness = false
+
 [features]
 cublas = ["llama-cpp-sys-2/cublas"]
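
The new [[bench]] target registers the benchmark added below, and harness = false hands main over to Criterion instead of the default libtest harness. Assuming the package keeps the name of its directory, the benchmark should be runnable with `cargo bench -p llama-cpp-2 --bench generate`.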

llama-cpp-2/benches/generate.rs

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+use anyhow::Context;
+use criterion::{Criterion, criterion_group, criterion_main};
+use pprof::criterion::{Output, PProfProfiler};
+use llama_cpp_2::context::params::LlamaContextParams;
+use llama_cpp_2::llama_backend::LlamaBackend;
+use llama_cpp_2::llama_batch::LlamaBatch;
+use llama_cpp_2::model::{AddBos, LlamaModel};
+use llama_cpp_2::model::params::LlamaModelParams;
+use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+
+fn generate(c: &mut Criterion) {
+    let api = hf_hub::api::sync::ApiBuilder::new()
+        .with_progress(true)
+        .build()
+        .unwrap();
+    let file = api
+        .model("TheBloke/Llama-2-7B-Chat-GGUF".to_string())
+        .get("llama-2-7b-chat.Q4_K_M.gguf")
+        .unwrap();
+    let backend = LlamaBackend::init().unwrap();
+    let model_params = LlamaModelParams::default();
+    let model = LlamaModel::load_from_file(&backend, &file, &model_params).unwrap();
+    let mut ctx = model
+        .new_context(&backend, LlamaContextParams::default())
+        .unwrap();
+
+    c.bench_function("generate 50 tokens", |b| {
+        b.iter(|| {
+            let tokens_list = model.str_to_token("Hello, my name is", AddBos::Always).unwrap();
+            let mut n_ctx = tokens_list.len() as i32;
+            let mut batch = LlamaBatch::new(512, 1);
+            let last_index: i32 = (tokens_list.len() - 1) as i32;
+            for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
+                let is_last = i == last_index;
+                batch.add(token, i, &[0], is_last).unwrap();
+            }
+            ctx.decode(&mut batch).unwrap();
+
+            for _ in 0..50 {
+                let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
+                let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
+                let new_token_id = ctx.sample_token_greedy(candidates_p);
+                if new_token_id == model.token_eos() {
+                    break;
+                }
+                batch.clear();
+                batch.add(new_token_id, n_ctx, &[0], true).unwrap();
+                n_ctx += 1;
+                ctx.decode(&mut batch).unwrap();
+            }
+            ctx.clear_kv_cache_seq(0, None, None)
+        });
+    });
+}
+
+criterion_group!(
+    name = benches;
+    config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
+    targets = generate
+);
+criterion_main!(benches);
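
The criterion_group! configuration attaches pprof's PProfProfiler with flamegraph output. Assuming the stock criterion + pprof integration, running the bench in profiling mode (for example `cargo bench --bench generate -- --profile-time 60`) should write a flamegraph SVG under target/criterion/ in addition to the normal timing report, while a plain `cargo bench` run skips the profiler. Note that the benchmark downloads llama-2-7b-chat.Q4_K_M.gguf from Hugging Face via hf_hub on first use, so the first run needs network access and a few gigabytes of disk space.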

llama-cpp-sys-2/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -43,3 +43,4 @@ cc = { workspace = true }

 [features]
 cublas = []
+

llama-cpp-sys-2/build.rs

Lines changed: 82 additions & 17 deletions
@@ -7,24 +7,23 @@ fn main() {

     let cublas_enabled = env::var("CARGO_FEATURE_CUBLAS").is_ok();

+    let mut ggml_cuda = if cublas_enabled { Some(cc::Build::new()) } else { None };
+
     if !Path::new("llama.cpp/ggml.c").exists() {
         panic!("llama.cpp seems to not be populated, try running `git submodule update --init --recursive` to init.")
     }

     let mut ggml = cc::Build::new();
-    let mut ggml_cuda = if cublas_enabled {
-        Some(cc::Build::new())
-    } else {
-        None
-    };
     let mut llama_cpp = cc::Build::new();

     ggml.cpp(false);
     llama_cpp.cpp(true);

     // https://github.com/ggerganov/llama.cpp/blob/a836c8f534ab789b02da149fbdaf7735500bff74/Makefile#L364-L368
     if let Some(ggml_cuda) = &mut ggml_cuda {
-        for lib in ["cuda", "cublas", "cudart", "cublasLt"] {
+        for lib in [
+            "cuda", "cublas", "culibos", "cudart", "cublasLt", "pthread", "dl", "rt",
+        ] {
             println!("cargo:rustc-link-lib={}", lib);
         }
         if !ggml_cuda.get_compiler().is_like_msvc() {
@@ -39,23 +38,27 @@ fn main() {
             ggml_cuda
                 .flag_if_supported("-mfp16-format=ieee")
                 .flag_if_supported("-mno-unaligned-access");
+            ggml.flag_if_supported("-mfp16-format=ieee")
+                .flag_if_supported("-mno-unaligned-access");
             llama_cpp
                 .flag_if_supported("-mfp16-format=ieee")
                 .flag_if_supported("-mno-unaligned-access");
-            ggml_cuda
-                .flag_if_supported("-mfp16-format=ieee")
+            ggml.flag_if_supported("-mfp16-format=ieee")
                 .flag_if_supported("-mno-unaligned-access");
         }

         ggml_cuda
             .cuda(true)
             .flag("-arch=all")
-            .file("llama.cpp/ggml-cuda.cu");
+            .file("llama.cpp/ggml-cuda.cu")
+            .include("llama.cpp");

         if ggml_cuda.get_compiler().is_like_msvc() {
             ggml_cuda.std("c++14");
         } else {
-            ggml_cuda.std("c++17");
+            ggml_cuda
+                .flag("-std=c++11")
+                .std("c++11");
         }

         ggml.define("GGML_USE_CUBLAS", None);
@@ -65,22 +68,36 @@ fn main() {

     // https://github.com/ggerganov/llama.cpp/blob/191221178f51b6e81122c5bda0fd79620e547d07/Makefile#L133-L141
     if cfg!(target_os = "macos") {
+        assert!(!cublas_enabled, "CUBLAS is not supported on macOS");
+
+        println!("cargo:rustc-link-lib=framework=Metal");
+        println!("cargo:rustc-link-lib=framework=Foundation");
+        println!("cargo:rustc-link-lib=framework=MetalPerformanceShaders");
+        println!("cargo:rustc-link-lib=framework=MetalKit");
+
         llama_cpp.define("_DARWIN_C_SOURCE", None);
+
+        // https://github.com/ggerganov/llama.cpp/blob/3c0d25c4756742ebf15ad44700fabc0700c638bd/Makefile#L340-L343
+        llama_cpp.define("GGML_USE_METAL", None);
+        llama_cpp.define("GGML_USE_ACCELERATE", None);
+        llama_cpp.define("ACCELERATE_NEW_LAPACK", None);
+        llama_cpp.define("ACCELERATE_LAPACK_ILP64", None);
+        println!("cargo:rustc-link-arg=framework=Accelerate");
+
+        metal_hack(&mut ggml);
+        ggml.include("./llama.cpp/ggml-metal.h");
     }
+
     if cfg!(target_os = "dragonfly") {
         llama_cpp.define("__BSD_VISIBLE", None);
     }

-    if let Some(ggml_cuda) = ggml_cuda {
-        println!("compiling ggml-cuda");
-        ggml_cuda.compile("ggml-cuda");
-    }
-
     if cfg!(target_os = "linux") {
         ggml.define("_GNU_SOURCE", None);
     }

-    ggml.std("c17")
+    ggml.std("c11")
+        .include("./llama.cpp")
         .file("llama.cpp/ggml.c")
         .file("llama.cpp/ggml-alloc.c")
         .file("llama.cpp/ggml-backend.c")
@@ -89,14 +106,23 @@ fn main() {

     llama_cpp
         .define("_XOPEN_SOURCE", Some("600"))
-        .std("c++17")
+        .include("llama.cpp")
+        .std("c++11")
         .file("llama.cpp/llama.cpp");

+    if let Some(ggml_cuda) = ggml_cuda {
+        println!("compiling ggml-cuda");
+        ggml_cuda.compile("ggml-cuda");
+        println!("compiled ggml-cuda");
+    }
+
     println!("compiling ggml");
     ggml.compile("ggml");
+    println!("compiled ggml");

     println!("compiling llama");
     llama_cpp.compile("llama");
+    println!("compiled llama");

     let header = "llama.cpp/llama.h";

@@ -116,3 +142,42 @@ fn main() {
         .write_to_file(out_path.join("bindings.rs"))
         .expect("failed to write bindings to file");
 }
+
+// courtesy of https://github.com/rustformers/llm
+fn metal_hack(build: &mut cc::Build) {
+    const GGML_METAL_METAL_PATH: &str = "llama.cpp/ggml-metal.metal";
+    const GGML_METAL_PATH: &str = "llama.cpp/ggml-metal.m";
+
+    let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR is not defined"));
+
+    let ggml_metal_path = {
+        let ggml_metal_metal = std::fs::read_to_string(GGML_METAL_METAL_PATH)
+            .expect("Could not read ggml-metal.metal")
+            .replace('\\', "\\\\")
+            .replace('\n', "\\n")
+            .replace('\r', "\\r")
+            .replace('\"', "\\\"");
+
+        let ggml_metal =
+            std::fs::read_to_string(GGML_METAL_PATH).expect("Could not read ggml-metal.m");
+
+        let needle = r#"NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error];"#;
+        if !ggml_metal.contains(needle) {
+            panic!("ggml-metal.m does not contain the needle to be replaced; the patching logic needs to be reinvestigated. Contact a `llama-cpp-sys-2` developer!");
+        }
+
+        // Replace the runtime read of the file with a compile-time string
+        let ggml_metal = ggml_metal.replace(
+            needle,
+            &format!(r#"NSString * src = @"{ggml_metal_metal}";"#),
+        );
+
+        let patched_ggml_metal_path = out_dir.join("ggml-metal.m");
+        std::fs::write(&patched_ggml_metal_path, ggml_metal)
+            .expect("Could not write temporary patched ggml-metal.m");
+
+        patched_ggml_metal_path
+    };
+
+    build.file(ggml_metal_path);
+}
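
Taken together, the build-script changes make Metal the default GPU backend on macOS: the Metal, Foundation, MetalPerformanceShaders, MetalKit, and Accelerate frameworks are linked whenever the target is macOS (no Cargo feature involved), GGML_USE_METAL and GGML_USE_ACCELERATE are defined, and the script now asserts that the cublas feature is not enabled there. metal_hack works around ggml-metal.m's runtime load of ggml-metal.metal by escaping the shader source and splicing it into ggml-metal.m as a string literal at build time, so the compiled crate does not need to find the .metal file next to the binary. The ggml-cuda compile step also moves to after all the builds are configured, with extra progress printouts. On other platforms CUDA remains opt-in via the existing cublas feature; a downstream Cargo.toml might enable it roughly like this (a sketch only; the path is hypothetical and should point at wherever the crate lives in your setup):

[dependencies]
# hypothetical dependency entry; replace the path with your actual source for llama-cpp-2
llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2", features = ["cublas"] }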
