
Commit 49f328b

Merge pull request #40 from menloresearch/update-dev-from-master-2025-04-02-00-08
Sync master with upstream release b5022
2 parents 18a20cf + f423981, commit 49f328b


44 files changed (+2141, -933 lines)

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
@@ -803,7 +803,7 @@ jobs:
     env:
       OPENBLAS_VERSION: 0.3.23
       SDE_VERSION: 9.33.0-2024-01-07
-      VULKAN_VERSION: 1.4.304.1
+      VULKAN_VERSION: 1.4.309.0

     strategy:
       matrix:

common/minja/minja.hpp

Lines changed: 2 additions & 2 deletions
@@ -240,7 +240,7 @@ class Value : public std::enable_shared_from_this<Value> {
       auto index = key.get<int>();
       return array_->at(index < 0 ? array_->size() + index : index);
     } else if (object_) {
-      if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
+      if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
       auto it = object_->find(key.primitive_);
       if (it == object_->end()) return Value();
       return it->second;
@@ -249,7 +249,7 @@ class Value : public std::enable_shared_from_this<Value> {
   }
   void set(const Value& key, const Value& value) {
     if (!object_) throw std::runtime_error("Value is not an object: " + dump());
-    if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
+    if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
     (*object_)[key.primitive_] = value;
   }
   Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {

convert_hf_to_gguf.py

Lines changed: 4 additions & 7 deletions
@@ -3557,8 +3557,8 @@ def set_gguf_parameters(self):
         head_size = hidden_size // num_attention_heads
         rms_norm_eps = self.hparams["rms_norm_eps"]
         intermediate_size = self.hparams["intermediate_size"]
-        time_mix_extra_dim = 64 if hidden_size >= 4096 else 32
-        time_decay_extra_dim = 128 if hidden_size >= 4096 else 64
+        time_mix_extra_dim = self.hparams.get("lora_rank_tokenshift", 64 if hidden_size >= 4096 else 32)
+        time_decay_extra_dim = self.hparams.get("lora_rank_decay", 128 if hidden_size >= 4096 else 64)

         # RWKV isn't context limited
         self.gguf_writer.add_context_length(1048576)
@@ -5146,10 +5146,7 @@ def set_vocab(self):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
-        if "head_dim" in hparams:
-            rope_dim = hparams["head_dim"]
-        else:
-            rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
+        rope_dim = hparams.get("head_dim") or hparams["hidden_size"] // hparams["num_attention_heads"]

         self.gguf_writer.add_rope_dimension_count(rope_dim)
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
@@ -5175,7 +5172,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")
         n_embd = self.hparams["hidden_size"]
-        head_dim = self.hparams.get("head_dim", n_embd // n_head)
+        head_dim = self.hparams.get("head_dim") or n_embd // n_head

         output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)

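For context on the `head_dim` hunks above: `dict.get(key, default)` only falls back when the key is absent, whereas `dict.get(key) or default` also falls back when the key is present but holds a falsy value such as `None`. A minimal sketch of the difference, assuming a hypothetical config that ships `"head_dim": null`:

```python
# Minimal sketch (not part of the patch): why `.get(key) or default` differs
# from `.get(key, default)` when a config stores an explicit null.
hparams = {"hidden_size": 4096, "num_attention_heads": 32, "head_dim": None}

# old form: falls back only if the key is missing, so it returns None here
head_dim_old = hparams.get("head_dim", hparams["hidden_size"] // hparams["num_attention_heads"])

# new form: also falls back when the stored value is None (or 0)
head_dim_new = hparams.get("head_dim") or hparams["hidden_size"] // hparams["num_attention_heads"]

print(head_dim_old)  # None
print(head_dim_new)  # 128
```
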
docs/backend/SYCL.md

Lines changed: 7 additions & 32 deletions
@@ -20,7 +20,7 @@
 **oneAPI** is an open ecosystem and a standard-based specification, supporting multiple architectures including but not limited to intel CPUs, GPUs and FPGAs. The key components of the oneAPI ecosystem include:

 - **DPCPP** *(Data Parallel C++)*: The primary oneAPI SYCL implementation, which includes the icpx/icx Compilers.
-- **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. oneMKL and oneDNN)*.
+- **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. Intel oneMKL, oneMath and oneDNN)*.
 - **oneAPI LevelZero**: A high performance low level interface for fine-grained control over intel iGPUs and dGPUs.
 - **Nvidia & AMD Plugins**: These are plugins extending oneAPI's DPCPP support to SYCL on Nvidia and AMD GPU targets.

@@ -227,16 +227,6 @@ Upon a successful installation, SYCL is enabled for the available intel devices,

 **oneAPI Plugin**: In order to enable SYCL support on Nvidia GPUs, please install the [Codeplay oneAPI Plugin for Nvidia GPUs](https://developer.codeplay.com/products/oneapi/nvidia/download). User should also make sure the plugin version matches the installed base toolkit one *(previous step)* for a seamless "oneAPI on Nvidia GPU" setup.

-
-**oneMKL for cuBlas**: The current oneMKL releases *(shipped with the oneAPI base-toolkit)* do not contain the cuBLAS backend. A build from source of the upstream [oneMKL](https://github.com/oneapi-src/oneMKL) with the *cuBLAS* backend enabled is thus required to run it on Nvidia GPUs.
-
-```sh
-git clone https://github.com/oneapi-src/oneMKL
-cd oneMKL
-cmake -B buildWithCublas -DCMAKE_CXX_COMPILER=icpx -DCMAKE_C_COMPILER=icx -DENABLE_MKLGPU_BACKEND=OFF -DENABLE_MKLCPU_BACKEND=OFF -DENABLE_CUBLAS_BACKEND=ON -DTARGET_DOMAINS=blas
-cmake --build buildWithCublas --config Release
-```
-
 **oneDNN**: The current oneDNN releases *(shipped with the oneAPI base-toolkit)* do not include the NVIDIA backend. Therefore, oneDNN must be compiled from source to enable the NVIDIA target:

 ```sh
@@ -250,16 +240,6 @@ cmake --build build-nvidia --config Release

 **oneAPI Plugin**: In order to enable SYCL support on AMD GPUs, please install the [Codeplay oneAPI Plugin for AMD GPUs](https://developer.codeplay.com/products/oneapi/amd/download). As with Nvidia GPUs, the user should also make sure the plugin version matches the installed base toolkit.

-**oneMKL for rocBlas**: The current oneMKL releases *(shipped with the oneAPI base-toolkit)* doesn't contain the rocBLAS backend. A build from source of the upstream [oneMKL](https://github.com/oneapi-src/oneMKL) with the *rocBLAS* backend enabled is thus required to run it on AMD GPUs.
-
-```sh
-git clone https://github.com/oneapi-src/oneMKL
-cd oneMKL
-# Find your HIPTARGET with rocminfo, under the key 'Name:'
-cmake -B buildWithrocBLAS -DCMAKE_CXX_COMPILER=icpx -DCMAKE_C_COMPILER=icx -DENABLE_MKLGPU_BACKEND=OFF -DENABLE_MKLCPU_BACKEND=OFF -DENABLE_ROCBLAS_BACKEND=ON -DHIPTARGETS=${HIPTARGET} -DTARGET_DOMAINS=blas
-cmake --build buildWithrocBLAS --config Release
-```
-
 3. **Verify installation and environment**

 In order to check the available SYCL devices on the machine, please use the `sycl-ls` command.
@@ -324,13 +304,10 @@ cmake --build build --config Release -j -v

 #### Nvidia GPU

-```sh
-# Export relevant ENV variables
-export LD_LIBRARY_PATH=/path/to/oneMKL/buildWithCublas/lib:$LD_LIBRARY_PATH
-export LIBRARY_PATH=/path/to/oneMKL/buildWithCublas/lib:$LIBRARY_PATH
-export CPLUS_INCLUDE_DIR=/path/to/oneMKL/buildWithCublas/include:$CPLUS_INCLUDE_DIR
-export CPLUS_INCLUDE_DIR=/path/to/oneMKL/include:$CPLUS_INCLUDE_DIR
+The SYCL backend depends on [oneMath](https://github.com/uxlfoundation/oneMath) for Nvidia and AMD devices.
+By default it is automatically built along with the project. A specific build can be provided by setting the CMake flag `-DoneMath_DIR=/path/to/oneMath/install/lib/cmake/oneMath`.

+```sh
 # Build LLAMA with Nvidia BLAS acceleration through SYCL
 # Setting GGML_SYCL_DEVICE_ARCH is optional but can improve performance
 GGML_SYCL_DEVICE_ARCH=sm_80 # Example architecture
@@ -347,12 +324,10 @@ cmake --build build --config Release -j -v

 #### AMD GPU

-```sh
-# Export relevant ENV variables
-export LD_LIBRARY_PATH=/path/to/oneMKL/buildWithrocBLAS/lib:$LD_LIBRARY_PATH
-export LIBRARY_PATH=/path/to/oneMKL/buildWithrocBLAS/lib:$LIBRARY_PATH
-export CPLUS_INCLUDE_DIR=/path/to/oneMKL/buildWithrocBLAS/include:$CPLUS_INCLUDE_DIR
+The SYCL backend depends on [oneMath](https://github.com/uxlfoundation/oneMath) for Nvidia and AMD devices.
+By default it is automatically built along with the project. A specific build can be provided by setting the CMake flag `-DoneMath_DIR=/path/to/oneMath/install/lib/cmake/oneMath`.

+```sh
 # Build LLAMA with rocBLAS acceleration through SYCL

 ## AMD

examples/llava/clip.cpp

Lines changed: 5 additions & 3 deletions
@@ -1396,14 +1396,16 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
     const int n_kv = gguf_get_n_kv(ctx);
     const int ftype = get_u32(ctx, KEY_FTYPE);
     const std::string ftype_str = get_ftype(ftype);
-    const int idx_desc = get_key_idx(ctx, KEY_DESCRIPTION);
-    const std::string description = gguf_get_val_str(ctx, idx_desc);
     const int idx_name = gguf_find_key(ctx, KEY_NAME);
     if (idx_name != -1) { // make name optional temporarily as some of the uploaded models missing it due to a bug
         const std::string name = gguf_get_val_str(ctx, idx_name);
         LOG_INF("%s: model name: %s\n", __func__, name.c_str());
     }
-    LOG_INF("%s: description: %s\n", __func__, description.c_str());
+    const int idx_desc = gguf_find_key(ctx, KEY_DESCRIPTION);
+    if (idx_desc != -1) { // ditto
+        const std::string description = gguf_get_val_str(ctx, idx_desc);
+        LOG_INF("%s: description: %s\n", __func__, description.c_str());
+    }
     LOG_INF("%s: GGUF version: %d\n", __func__, gguf_get_version(ctx));
     LOG_INF("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx));
     LOG_INF("%s: n_tensors: %d\n", __func__, n_tensors);

examples/tts/tts.cpp

Lines changed: 5 additions & 3 deletions
@@ -699,11 +699,13 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
     const std::string voice_data = audio_data;

     auto tmp = common_tokenize(vocab, voice_data, false, true);
-    printf("\n\n");
+
+    std::ostringstream tokens_oss;
     for (size_t i = 0; i < tmp.size(); ++i) {
-        printf("%d, ", tmp[i]);
+        tokens_oss << tmp[i] << ", ";
     }
-    printf("\n\n");
+    LOG_INF("\n\n%s: llama tokens: %s\n\n", __func__, tokens_oss.str().c_str());
+
     prompt_add(prompt_inp, tmp);
 #else
     prompt_add(prompt_inp, llama_tokens {

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 10 additions & 0 deletions
@@ -31,6 +31,8 @@
 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/scale.cuh"
 #include "ggml-cuda/softmax.cuh"
+#include "ggml-cuda/ssm-conv.cuh"
+#include "ggml-cuda/ssm-scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
 #include "ggml-cuda/tsembd.cuh"
@@ -2296,6 +2298,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SUM_ROWS:
             ggml_cuda_op_sum_rows(ctx, dst);
             break;
+        case GGML_OP_SSM_CONV:
+            ggml_cuda_op_ssm_conv(ctx, dst);
+            break;
+        case GGML_OP_SSM_SCAN:
+            ggml_cuda_op_ssm_scan(ctx, dst);
+            break;
         case GGML_OP_ARGSORT:
             ggml_cuda_op_argsort(ctx, dst);
             break;
@@ -3193,6 +3201,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
         case GGML_OP_LOG:
+        case GGML_OP_SSM_SCAN:
+        case GGML_OP_SSM_CONV:
             return true;
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;

ggml/src/ggml-cuda/ssm-conv.cu

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
+#include "ssm-conv.cuh"
+
+template <size_t split_d_inner, size_t d_conv>
+static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
+                                    const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
+                                    float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
+                                    const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
+    const int tid  = threadIdx.x;
+    const int bidx = blockIdx.x;
+    const int bidy = blockIdx.y;
+
+    const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
+    const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
+    float *       y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);
+
+    const int stride_x = src0_nb1 / sizeof(float);
+    const int stride_w = src1_nb1 / sizeof(float);
+    const int stride_y = dst_nb1 / sizeof(float);
+
+    float x[d_conv] = { 0.0f };
+    float w[d_conv] = { 0.0f };
+
+#pragma unroll
+    for (int j = 0; j < d_conv; j++) {
+        w[j] = w_block[tid * stride_w + j];
+    }
+
+    for (int i = 0; i < n_t; i++) {
+        float sumf = 0.0f;
+
+        if (i == 0) {
+            for (int j = 0; j < d_conv; j++) {
+                x[j] = x_block[tid * stride_x + j];
+            }
+        } else {
+            x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
+        }
+
+#pragma unroll
+        for (int j = 0; j < d_conv; j++) {
+            sumf += x[(i + j) % d_conv] * w[j];
+        }
+        y_block[i * stride_y + tid] = sumf;
+    }
+}
+
+template <size_t split_d_inner, size_t d_conv, size_t split_n_t>
+static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
+                                               const int src0_nb0, const int src0_nb1, const int src0_nb2,
+                                               const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
+                                               const int dst_nb1, const int dst_nb2, const int nc, const int ncs,
+                                               const int nr, const int n_t, const int n_s) {
+    const int tid  = threadIdx.x;
+    const int bidx = blockIdx.x;
+    const int bidy = blockIdx.y;
+    const int bidz = blockIdx.z;
+
+    const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
+                                             bidz * split_n_t * src0_nb0);
+    const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
+    float *       y_block =
+        (float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);
+
+    const int stride_x = src0_nb1 / sizeof(float);
+    const int stride_w = src1_nb1 / sizeof(float);
+    const int stride_y = dst_nb1 / sizeof(float);
+
+    float x[d_conv] = { 0.0f };
+    float w[d_conv] = { 0.0f };
+
+#pragma unroll
+    for (int j = 0; j < d_conv; j++) {
+        w[j] = w_block[tid * stride_w + j];
+    }
+
+#pragma unroll
+    for (int i = 0; i < split_n_t; i++) {
+        if (bidz * split_n_t + i < n_t) {
+            float sumf = 0.0f;
+
+            if (i == 0) {
+                for (int j = 0; j < d_conv; j++) {
+                    x[j] = x_block[tid * stride_x + j];
+                }
+            } else {
+                x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
+            }
+
+#pragma unroll
+            for (int j = 0; j < d_conv; j++) {
+                sumf += x[(i + j) % d_conv] * w[j];
+            }
+            y_block[i * stride_y + tid] = sumf;
+        }
+    }
+}
+
+static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
+                              const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
+                              const int dst_nb2, const int nc, const int ncs, const int nr, const int n_t,
+                              const int n_s, cudaStream_t stream) {
+    const int threads = 128;
+    GGML_ASSERT(nr % threads == 0);
+
+    if (n_t <= 32) {
+        const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
+        if (nc == 4) {
+            ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+                                                                     dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t,
+                                                                     n_s);
+        } else {
+            GGML_ABORT("Only support kernel size = 4 now.");
+        }
+    } else {
+        if (nc == 4) {
+            const int split_n_t = 32;
+            dim3      blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
+            ssm_conv_long_token_f32<threads, 4, split_n_t>
+                <<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0,
+                                                 dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s);
+        } else {
+            GGML_ABORT("Only support kernel size = 4 right now.");
+        }
+    }
+}
+
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];  // conv_x
+    const struct ggml_tensor * src1 = dst->src[1];  // conv1d.weight
+
+    const int nc  = src1->ne[0];  // d_conv
+    const int ncs = src0->ne[0];  // d_conv - 1 + n_t
+    const int nr  = src0->ne[1];  // d_inner
+    const int n_t = dst->ne[1];   // tokens per sequence
+    const int n_s = dst->ne[2];   // number of sequences in the batch
+
+    GGML_ASSERT(dst->ne[0] == nr);
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float *       dst_d  = (float *) dst->data;
+    cudaStream_t  stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
+                      dst->nb[2], nc, ncs, nr, n_t, n_s, stream);
+}
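
Reading the new kernels, the operation appears to be a per-row sliding dot product: for each sequence and each of the `nr` rows, `y[t] = Σ_j x[t + j] · w[j]` over a window of `d_conv` elements, with the input holding `n_t + d_conv - 1` columns per row. Below is a rough, dependency-free Python reference under that reading; the layout and names are illustrative, not the actual ggml tensor layout.

```python
def ssm_conv_ref(x, w, d_conv, n_t):
    """Reference for the sliding dot product the CUDA kernel appears to compute.

    x: padded input rows, x[r][c] with c in [0, n_t + d_conv - 1)
    w: per-row kernel,    w[r][j] with j in [0, d_conv)
    returns y with y[t][r] = sum_j x[r][t + j] * w[r][j]
    """
    n_rows = len(x)
    y = [[0.0] * n_rows for _ in range(n_t)]
    for r in range(n_rows):
        for t in range(n_t):
            y[t][r] = sum(x[r][t + j] * w[r][j] for j in range(d_conv))
    return y

# tiny smoke test: d_conv = 4, one row, three output tokens
x = [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]        # n_t + d_conv - 1 = 6 columns
w = [[0.25, 0.25, 0.25, 0.25]]              # simple moving average
print(ssm_conv_ref(x, w, d_conv=4, n_t=3))  # [[2.5], [3.5], [4.5]]
```

The kernels avoid re-reading the whole window at every step by keeping the last `d_conv` inputs in registers as a small circular buffer and loading a single new element per token; the reference above trades that optimization for clarity.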

ggml/src/ggml-cuda/ssm-conv.cuh

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
