Skip to content

Conversation

Thireus
Copy link
Contributor

@Thireus Thireus commented Aug 1, 2025

Port from llama.cpp - See ggml-org/llama.cpp#14939
Ported from https://github.com/sammcj/llama.cpp/tree/dcbbd2cb057a6c6e907e0195395a74201ef19e1b
Old conversation - #662

Still some cleanup work to do, such as the GGML_ASSERT which needs to be restored. Contributors are welcomed.

Want to test it out?

# Clone patched llama.cpp for glm-4.5
git clone https://github.com/ikawrakow/ik_llama.cpp
cd ik_llama.cpp
git remote add Thireus https://github.com/Thireus/ik_llama.cpp.git
git fetch Thireus
git checkout glm-4.5-clean
git pull
git rev-parse --short HEAD

Instructions to convert HF to GGUF:

→ Alternatively you can download GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT - must be used with ulimit -n 9999

# Install dependencies
apt-get install python3-dev python3-pip python3-venv python3-wheel python3-setuptools git cmake pipx ccache
apt-get install --no-install-recommends zlib1g-dev libxml2-dev libssl-dev libgmp-dev libmpfr-dev

# Install uv via pipx
pipx install uv

# Prepare env
mkdir -p ~/AI/hf-to-bf16
cd ~/AI/hf-to-bf16
uv venv ./venv --python 3.12 --python-preference=only-managed

# Activate env
source venv/bin/activate

# Clone patched ik_llama.cpp for glm-4.5, if not already available
git clone https://github.com/ikawrakow/ik_llama.cpp
cd ik_llama.cpp
git remote add Thireus https://github.com/Thireus/ik_llama.cpp.git
git fetch Thireus
git checkout glm-4.5-clean
git pull
git rev-parse --short HEAD

# Install ik_llama.cpp dependencies
cd ~/AI/hf-to-bf16/ik_llama.cpp
uv pip install -r requirements/requirements-convert_hf_to_gguf.txt --prerelease=allow --index-strategy unsafe-best-match
# Build ik_llama.cpp (optional)
cmake -B build -DGGML_AVX=ON -DGGML_AVX2=ON -DLLAMA_CURL=OFF
cmake --build build --config Release -j16
cd ..

# Build triton-cpu
cd ~/AI/hf-to-bf16/
git clone https://github.com/triton-lang/triton-cpu --recursive
cd triton-cpu
uv pip install ninja cmake wheel setuptools pybind11

# Apply this patch - https://github.com/ikawrakow/ik_llama.cpp/issues/383#issuecomment-2865306085
nano -w CMakeLists.txt
---
#  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-covered-switch-default -fvisibility=hidden")
---
nano -w third_party/cpu/CMakeLists.txt
---
#find_package(dnnl CONFIG)
#if (dnnl_FOUND)
#... comment all the lines up until the endif
#endif()
---

# Install dependencies
uv pip install -r python/requirements.txt

# Compile
MAX_JOBS=16 uv pip install -e python --no-build-isolation

# Be patient, "Preparing Packages" downloads a lot of stuff before build begins...

# Convert HF to BF16 for GLM-4.5
# It is assumed that GLM-4.5 was downloaded from https://huggingface.co/zai-org/GLM-4.5 in ~/AI/huggingface/GLM-4.5
cd ~/AI/hf-to-bf16/ik_llama.cpp
mkdir -p ~/AI/GLM-4.5/GLM-4_5-BF16
python convert_hf_to_gguf.py \
     --outtype bf16 \
     --outfile ~/AI/GLM-4.5/GLM-4_5-BF16/GLM-4_5-BF16\
     --no-tensor-first-split \
     ~/AI/huggingface/GLM-4.5

Basic instructions to quantize from BF16 to q4_K quant:

→ Alternatively you can download GLM-4.5-THIREUS-Q4_K-SPECIAL_SPLIT - must be used with ulimit -n 9999 and ik_llama built with -DGGML_MAX_CONTEXTS=2048
IMPORTANT: This is not how these SPECIAL_SPLIT quants are supposed to be used, as they must be used in a recipe (read more at https://gguf.thireus.com), so expect quite bad perplexity. However still usable this way for testing purpose.

# Example with q4_K
mkdir ~/AI/GLM-4.5/GLM-4_5-Q4_K
llama-quantize ~/AI/GLM-4.5/GLM-4_5-BF16/GLM-4_5-BF16-00001-of-01762.gguf ~/AI/GLM-4.5/GLM-4_5-Q4_K/GLM-4_5-Q4_K.gguf q4_K 8

Basic instructions to lauch llama-server:

llama-server -m ~/AI/GLM-4.5/GLM-4_5-Q4_K/GLM-4_5-Q4_K.gguf -fa \
  -amb 1024 \
  -fmoe \
  -ctk f16 \
  -c 4096 \
  -ngl 99 \
  -ot "blk\.([0-9]|1[0-5])\.ffn_.*=CUDA0" \
  -ot exps=CPU \
  -b 2048 -ub 1024 \
  --warmup-batch \
  --no-mmap \
  --threads 8 \
  --main-gpu 0

Quant mix recipe for the adventurous:

Windows builds available at: https://github.com/Thireus/ik_llama.cpp/releases

@Thireus Thireus mentioned this pull request Aug 1, 2025
4 tasks
Copy link
Contributor

@ubergarm ubergarm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just two quick observations and trying to better understand how you ported the mainline PR as it is a bit different (did you make some changes, or was this port from a couple days ago etc).

Anyway, thanks for opening a clean PR to make it easier to look at the differences!

// Not sure if it is really a matter of insufficient precision, or I have made a mistake in the fattn-vec-f16 kernel.
if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ||
(model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4) {
(model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 || model.arch == LLM_ARCH_GLM4_MOE) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you test perplexity without this to see if it runs clean? Usually only add a model in here if perplexity is throwing nans without it.

Copy link
Contributor Author

@Thireus Thireus Aug 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At that time the model was still not working (or I had not tried llama-server); so I was trying my best to find any possible change I could make to make the model work, without putting much thoughts into it. I had noticed you made those changes in one of your old PR, so I followed your steps. It is entirely possible that this change may not be required after all.

src/llama.cpp Outdated
}

// crop output on last layer
if (il == n_layer - 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This if is slightly different than mainline: https://github.com/ggml-org/llama.cpp/pull/14939/files#diff-36e262e316ec1404e29880eb8b8ce4660ac584f0d0434710efc48a66497bdb59R13557

On mainline it is:

if (il == n_layer - 1 && inp_out_ids) {

I haven't followed the mainline PR closely enough to see if this changed recently nor looked in depth yet at what inp_out_ids even is supposed to be over there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Thireus
Copy link
Contributor Author

Thireus commented Aug 1, 2025

Just two quick observations and trying to better understand how you ported the mainline PR as it is a bit different (did you make some changes, or was this port from a couple days ago etc).

Anyway, thanks for opening a clean PR to make it easier to look at the differences!

This is a port from a few days ago indeed - to the best of my knowledge it was https://github.com/ggml-org/llama.cpp/pull/14939/files/fae4df8ee05c57bf2e2d5e3f0b094010f34dc86f. With slight variations as I was trying to best match existing implementations such as DeepSeek to make the port compatible with ik_llama.

@Thireus
Copy link
Contributor Author

Thireus commented Aug 1, 2025

@ubergarm - /nothink seems to be working just fine. I'm not witnessing the issues they are reporting on the llama.cpp implementation.

@ubergarm
Copy link
Contributor

ubergarm commented Aug 1, 2025

My other thought about chat template stuff is after running gguf dump i see:

$ python gguf-py/scripts/gguf_dump.py /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/GLM-4.5-Thireus-Q8_0.gguf
...
     44: STRING     |        1 | tokenizer.chat_template = '[gMASK]<sop>\n{%- if tools -%}\n<|system|>\n# Tools\n\nYou may ca'

This looks right with what the official chat template is here: https://huggingface.co/zai-org/GLM-4.5?chat_template=default

When starting the model with llama-server it says:

INFO [                    main] chat template | tid="123750572615872" timestamp=1754072872 chat_example="<|system|>\nYou are a helpful assistant\n<|user|>\nHello\n<|assistant|>\nHi there\n<|user|>\nHow are you?\n<|assistant|>\n" built_in=true

So for /chat/completions/ API endpoint it seems like it will be properly applying the basic template, however I haven't run their jinja through the decoder script (i gotta find my notes hah) and made sure it matches up closely enough. But in limited testing llama-server seems to be working okay.

I'll upload your imatrix soon after it finishes to here: https://huggingface.co/ubergarm/GLM-4.5-GGUF

Once again though, I'm hesitant to release quants just yet until things are more tested, but excited to see this coming along! Thanks for your efforts!

@ubergarm
Copy link
Contributor

ubergarm commented Aug 1, 2025

I'll already have GLM-4.5-Air bf16 safetensors downloaded, so will try this PR going through the whole process and see how it comes out for testing.

@Thireus
Copy link
Contributor Author

Thireus commented Aug 1, 2025

I'll already have GLM-4.5-Air bf16 safetensors downloaded, so will try this PR going through the whole process and see how it comes out for testing.

I was about to that that this is only tested on GLM-4.5, not the others. Good thinking!

@ddh0
Copy link

ddh0 commented Aug 1, 2025

Also cc @AesSedai

@ubergarm
Copy link
Contributor

ubergarm commented Aug 1, 2025

@Thireus

Your imatrix is uploaded here: https://huggingface.co/ubergarm/GLM-4.5-GGUF/blob/main/imatrix-GLM-4.5-BF16.dat

Here are the logs of the run including --layer-similarity if you are interested:

👈 imatrix logs and command
#!/usr/bin/env bash

ulimit -n 9999

# echo 0 | sudo tee /proc/sys/kernel/numa_balancing
# sudo sync; echo 3 | sudo tee /proc/sys/vm/drop_caches

model=/mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/GLM-4.5-THIREUS-BF16-SPECIAL_TENSOR-00001-of-01762.gguf

numactl -N 1 -m 1 \
./build/bin/llama-imatrix \
    -m "$model" \
    -fa \
    -f ubergarm-imatrix-calibration-corpus-v02.txt \
    -o /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat \
    --verbosity 1 \
    --layer-similarity \
    --seed 1337 \
    --ctx-size 512 \
    -ub 4096 -b 4096 \
    --numa numactl \
    --threads 128 \
    --threads-batch 192 \
    --no-mmap

llama_model_loader: loaded meta data with 44 key-value pairs and 1761 tensors from /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/GLM-4.5-THIREUS-BF16-SPECIAL_TENSOR-00001-of-01762.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = glm4moe
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = GLM 4.5
llama_model_loader: - kv   3:                            general.version str              = 4.5
llama_model_loader: - kv   4:                           general.basename str              = GLM
llama_model_loader: - kv   5:                         general.size_label str              = 160x21B
llama_model_loader: - kv   6:                            general.license str              = mit
llama_model_loader: - kv   7:                               general.tags arr[str,1]       = ["text-generation"]
llama_model_loader: - kv   8:                          general.languages arr[str,2]       = ["en", "zh"]
llama_model_loader: - kv   9:                        glm4moe.block_count u32              = 93
llama_model_loader: - kv  10:                     glm4moe.context_length u32              = 131072
llama_model_loader: - kv  11:                   glm4moe.embedding_length u32              = 5120
llama_model_loader: - kv  12:                glm4moe.feed_forward_length u32              = 12288
llama_model_loader: - kv  13:               glm4moe.attention.head_count u32              = 96
llama_model_loader: - kv  14:            glm4moe.attention.head_count_kv u32              = 8
llama_model_loader: - kv  15:                     glm4moe.rope.freq_base f32              = 1000000.000000
llama_model_loader: - kv  16:   glm4moe.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  17:                  glm4moe.expert_used_count u32              = 8
llama_model_loader: - kv  18:               glm4moe.attention.key_length u32              = 128
llama_model_loader: - kv  19:             glm4moe.attention.value_length u32              = 128
llama_model_loader: - kv  20:                          general.file_type u32              = 32
llama_model_loader: - kv  21:               glm4moe.rope.dimension_count u32              = 64
llama_model_loader: - kv  22:                       glm4moe.expert_count u32              = 160
llama_model_loader: - kv  23:         glm4moe.expert_feed_forward_length u32              = 1536
llama_model_loader: - kv  24:                glm4moe.expert_shared_count u32              = 1
llama_model_loader: - kv  25:          glm4moe.leading_dense_block_count u32              = 3
llama_model_loader: - kv  26:                 glm4moe.expert_gating_func u32              = 2
llama_model_loader: - kv  27:               glm4moe.expert_weights_scale f32              = 2.500000
llama_model_loader: - kv  28:                glm4moe.expert_weights_norm bool             = true
llama_model_loader: - kv  29:               general.quantization_version u32              = 2
llama_model_loader: - kv  30:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  31:                         tokenizer.ggml.pre str              = glm4
llama_model_loader: - kv  32:                      tokenizer.ggml.tokens arr[str,151552]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  33:                  tokenizer.ggml.token_type arr[i32,151552]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  34:                      tokenizer.ggml.merges arr[str,318088]  = ["Ġ Ġ", "Ġ ĠĠĠ", "ĠĠ ĠĠ", "...
llama_model_loader: - kv  35:                tokenizer.ggml.eos_token_id u32              = 151329
llama_model_loader: - kv  36:            tokenizer.ggml.padding_token_id u32              = 151329
llama_model_loader: - kv  37:                tokenizer.ggml.eot_token_id u32              = 151336
llama_model_loader: - kv  38:            tokenizer.ggml.unknown_token_id u32              = 151329
llama_model_loader: - kv  39:                tokenizer.ggml.bos_token_id u32              = 151329
llama_model_loader: - kv  40:                    tokenizer.chat_template str              = [gMASK]<sop>\n{%- if tools -%}\n<|syste...
llama_model_loader: - kv  41:                                   split.no u16              = 0
llama_model_loader: - kv  42:                                split.count u16              = 1762
llama_model_loader: - kv  43:                        split.tensors.count i32              = 1761
llama_model_loader: - type  f32:  838 tensors
llama_model_loader: - type bf16:  923 tensors
llm_load_vocab: special tokens cache size = 36
llm_load_vocab: token to piece cache size = 0.9713 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = glm4moe
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 151552
llm_load_print_meta: n_merges         = 318088
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 131072
llm_load_print_meta: n_embd           = 5120
llm_load_print_meta: n_layer          = 93
llm_load_print_meta: n_head           = 96
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_rot            = 64
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_swa_pattern    = 1
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 12
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 12288
llm_load_print_meta: n_expert         = 160
llm_load_print_meta: n_expert_used    = 8
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 2
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 1000000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 131072
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = 355B.A32B
llm_load_print_meta: model ftype      = BF16
llm_load_print_meta: model params     = 358.338 B
llm_load_print_meta: model size       = 670.586 GiB (16.075 BPW) 
llm_load_print_meta: repeating layers = 667.696 GiB (16.075 BPW, 356.786 B parameters)
llm_load_print_meta: general.name     = GLM 4.5
llm_load_print_meta: BOS token        = 151329 '<|endoftext|>'
llm_load_print_meta: EOS token        = 151329 '<|endoftext|>'
llm_load_print_meta: UNK token        = 151329 '<|endoftext|>'
llm_load_print_meta: PAD token        = 151329 '<|endoftext|>'
llm_load_print_meta: LF token         = 128 'Ä'
llm_load_print_meta: EOT token        = 151336 '<|user|>'
llm_load_print_meta: max token length = 1024
llm_load_tensors: ggml ctx size =    0.72 MiB
llm_load_tensors:        CPU buffer size = 686680.19 MiB
....................................................................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: n_batch    = 512
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 1
llama_new_context_with_model: mla_attn   = 0
llama_new_context_with_model: attn_max_b = 0
llama_new_context_with_model: fused_moe  = 0
llama_new_context_with_model: ser        = -1, 0
llama_new_context_with_model: freq_base  = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:        CPU KV buffer size =   186.00 MiB
llama_new_context_with_model: KV self size  =  186.00 MiB, K (f16):   93.00 MiB, V (f16):   93.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.58 MiB
llama_new_context_with_model:        CPU compute buffer size =   306.00 MiB
llama_new_context_with_model: graph nodes  = 4779
llama_new_context_with_model: graph splits = 1

system_info: n_threads = 128 (n_threads_batch = 192) / 768 | AVX = 1 | AVX_VNNI = 1 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | AVX512_BF16 = 1 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
compute_imatrix: tokenizing the input ..
compute_imatrix: tokenization took 901.12 ms
compute_imatrix: computing over 814 chunks with batch_size 512
compute_imatrix: 10.08 seconds per pass - ETA 2 hours 16.72 minutes
[1]16.8092,[2]6.7403,[3]4.3630,[4]3.1866,[5]2.5865,[6]2.2122,[7]1.9898,[8]1.8430,[9]1.8314,
save_imatrix: entry '             blk.92.ffn_gate_exps.weight' has partial data (98.75%) 2 out of 160 experts are missing data Storing **but be aware**
save_imatrix: entry '             blk.51.ffn_down_exps.weight' has partial data (99.38%) 1 out of 160 experts are missing data Storing **but be aware**
save_imatrix: entry '             blk.48.ffn_gate_exps.weight' has partial data (99.38%) 1 out of 160 experts are missing data Storing **but be aware**
save_imatrix: entry '               blk.51.ffn_up_exps.weight' has partial data (99.38%) 1 out of 160 experts are missing data Storing **but be aware**
save_imatrix: entry '               blk.30.ffn_up_exps.weight' has partial data (99.38%) 1 out of 160 experts are missing data Storing **but be aware**
save_imatrix: entry '             blk.29.ffn_down_exps.weight' has partial data (99.38%) 1 out of 160 experts are missing data Storing **but be aware**
save_imatrix: entry '             blk.29.ffn_gate_exps.weight' has partial data (99.38%) 1 out of 160 experts are missing data Storing **but be aware**
save_imatrix: entry '               blk.29.ffn_up_exps.weight' has partial data (99.38%) 1 out of 160 experts are missing data Storing **but be aware**
save_imatrix: entry '             blk.48.ffn_down_exps.weight' has partial data (99.38%) 1 out of 160 experts are missing data Storing **but be aware**
save_imatrix: entry '               blk.92.ffn_up_exps.weight' has partial data (98.75%) 2 out of 160 experts are missing data Storing **but be aware**
save_imatrix: entry '             blk.30.ffn_gate_exps.weight' has partial data (99.38%) 1 out of 160 experts are missing data Storing **but be aware**
save_imatrix: entry '             blk.30.ffn_down_exps.weight' has partial data (99.38%) 1 out of 160 experts are missing data Storing **but be aware**
save_imatrix: entry '             blk.51.ffn_gate_exps.weight' has partial data (99.38%) 1 out of 160 experts are missing data Storing **but be aware**
save_imatrix: entry '             blk.92.ffn_down_exps.weight' has partial data (98.75%) 2 out of 160 experts are missing data Storing **but be aware**
save_imatrix: entry '               blk.48.ffn_up_exps.weight' has partial data (99.38%) 1 out of 160 experts are missing data Storing **but be aware**

save_imatrix: stored collected data after 10 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[10]1.7538,[11]1.8628,[12]1.9497,[13]2.0128,[14]2.0758,[15]1.9792,[16]1.9022,[17]1.8481,[18]1.7962,[19]1.7420,
save_imatrix: stored collected data after 20 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[20]1.7069,[21]1.6655,[22]1.6390,[23]1.6079,[24]1.5804,[25]1.5523,[26]1.6324,[27]1.7262,[28]1.8429,[29]1.8173,
save_imatrix: stored collected data after 30 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[30]1.8001,[31]1.8139,[32]1.8058,[33]1.8766,[34]1.8526,[35]1.8476,[36]1.8334,[37]1.8276,[38]1.8529,[39]1.8663,
save_imatrix: stored collected data after 40 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[40]1.8546,[41]1.8837,[42]1.8914,[43]1.9060,[44]1.9156,[45]1.9178,[46]1.9054,[47]1.9153,[48]1.9143,[49]1.9182,
save_imatrix: stored collected data after 50 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[50]1.9079,[51]1.9235,[52]1.9356,[53]1.9249,[54]1.9343,[55]1.9362,[56]1.9417,[57]1.9357,[58]1.9853,[59]2.0340,
save_imatrix: stored collected data after 60 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[60]2.0820,[61]2.0979,[62]2.1400,[63]2.1680,[64]2.1632,[65]2.1640,[66]2.1671,[67]2.1524,[68]2.1657,[69]2.2064,
save_imatrix: stored collected data after 70 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[70]2.2582,[71]2.2854,[72]2.3223,[73]2.3531,[74]2.3716,[75]2.3998,[76]2.4148,[77]2.4411,[78]2.4389,[79]2.4221,
save_imatrix: stored collected data after 80 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[80]2.4199,[81]2.4225,[82]2.4477,[83]2.4895,[84]2.5059,[85]2.5095,[86]2.5137,[87]2.5037,[88]2.5068,[89]2.4945,
save_imatrix: stored collected data after 90 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[90]2.4849,[91]2.4808,[92]2.4657,[93]2.4469,[94]2.4741,[95]2.5218,[96]2.5412,[97]2.5434,[98]2.5515,[99]2.5706,
save_imatrix: stored collected data after 100 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[100]2.5892,[101]2.5974,[102]2.5991,[103]2.6306,[104]2.6540,[105]2.6476,[106]2.6845,[107]2.7273,[108]2.7577,[109]2.7980,
save_imatrix: stored collected data after 110 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[110]2.8276,[111]2.8644,[112]2.8963,[113]2.8919,[114]2.9083,[115]2.9220,[116]2.9303,[117]2.9364,[118]2.9665,[119]3.0045,
save_imatrix: stored collected data after 120 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[120]3.0421,[121]3.0367,[122]3.0118,[123]2.9970,[124]3.0177,[125]3.0060,[126]2.9822,[127]2.9828,[128]2.9806,[129]2.9875,
save_imatrix: stored collected data after 130 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[130]2.9961,[131]3.0125,[132]3.0298,[133]3.0353,[134]3.0729,[135]3.0906,[136]3.0662,[137]3.0415,[138]3.0195,[139]2.9963,
save_imatrix: stored collected data after 140 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[140]3.0054,[141]3.0137,[142]3.0481,[143]3.0755,[144]3.0806,[145]3.0999,[146]3.1236,[147]3.1461,[148]3.1797,[149]3.2059,
save_imatrix: stored collected data after 150 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[150]3.2329,[151]3.2507,[152]3.2723,[153]3.2842,[154]3.2903,[155]3.2860,[156]3.3026,[157]3.3125,[158]3.3244,[159]3.3376,
save_imatrix: stored collected data after 160 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[160]3.3521,[161]3.3566,[162]3.3596,[163]3.3730,[164]3.3788,[165]3.3865,[166]3.3999,[167]3.4016,[168]3.4032,[169]3.4102,
save_imatrix: stored collected data after 170 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[170]3.4193,[171]3.4246,[172]3.4296,[173]3.4358,[174]3.4532,[175]3.4650,[176]3.4702,[177]3.4760,[178]3.4935,[179]3.4819,
save_imatrix: stored collected data after 180 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[180]3.4896,[181]3.5034,[182]3.5243,[183]3.5399,[184]3.5470,[185]3.5495,[186]3.5481,[187]3.5473,[188]3.5489,[189]3.5499,
save_imatrix: stored collected data after 190 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[190]3.5511,[191]3.5486,[192]3.5701,[193]3.5991,[194]3.6236,[195]3.6501,[196]3.6697,[197]3.7037,[198]3.7149,[199]3.7318,
save_imatrix: stored collected data after 200 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[200]3.7226,[201]3.7382,[202]3.7311,[203]3.7091,[204]3.6876,[205]3.7082,[206]3.7229,[207]3.7318,[208]3.7417,[209]3.7609,
save_imatrix: stored collected data after 210 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[210]3.7771,[211]3.7935,[212]3.8121,[213]3.8290,[214]3.8321,[215]3.8110,[216]3.7881,[217]3.7656,[218]3.7432,[219]3.7212,
save_imatrix: stored collected data after 220 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[220]3.7055,[221]3.7036,[222]3.6939,[223]3.6892,[224]3.6760,[225]3.6579,[226]3.6574,[227]3.6620,[228]3.6838,[229]3.7077,
save_imatrix: stored collected data after 230 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[230]3.7183,[231]3.7402,[232]3.7358,[233]3.7603,[234]3.7894,[235]3.8015,[236]3.8160,[237]3.8175,[238]3.8411,[239]3.8694,
save_imatrix: stored collected data after 240 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[240]3.8659,[241]3.8761,[242]3.8903,[243]3.9100,[244]3.9294,[245]3.9438,[246]3.9562,[247]3.9656,[248]3.9554,[249]3.9816,
save_imatrix: stored collected data after 250 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[250]3.9954,[251]4.0137,[252]4.0246,[253]4.0289,[254]4.0353,[255]4.0385,[256]4.0511,[257]4.0559,[258]4.0678,[259]4.0822,
save_imatrix: stored collected data after 260 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[260]4.0923,[261]4.1029,[262]4.1146,[263]4.1297,[264]4.1417,[265]4.1577,[266]4.1437,[267]4.1496,[268]4.1544,[269]4.1681,
save_imatrix: stored collected data after 270 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[270]4.1888,[271]4.2028,[272]4.2240,[273]4.2244,[274]4.2225,[275]4.2310,[276]4.2363,[277]4.2530,[278]4.2672,[279]4.2808,
save_imatrix: stored collected data after 280 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[280]4.2902,[281]4.2913,[282]4.3053,[283]4.3167,[284]4.3194,[285]4.3358,[286]4.3368,[287]4.3410,[288]4.3498,[289]4.3473,
save_imatrix: stored collected data after 290 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[290]4.3583,[291]4.3637,[292]4.3700,[293]4.3864,[294]4.3991,[295]4.4124,[296]4.4292,[297]4.4341,[298]4.4530,[299]4.4657,
save_imatrix: stored collected data after 300 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[300]4.4822,[301]4.4951,[302]4.5095,[303]4.5144,[304]4.5336,[305]4.5411,[306]4.5451,[307]4.5536,[308]4.5710,[309]4.5813,
save_imatrix: stored collected data after 310 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[310]4.5858,[311]4.5939,[312]4.6025,[313]4.6165,[314]4.6244,[315]4.6338,[316]4.6466,[317]4.6587,[318]4.6726,[319]4.6779,
save_imatrix: stored collected data after 320 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[320]4.6821,[321]4.6755,[322]4.6849,[323]4.6688,[324]4.6853,[325]4.6888,[326]4.6671,[327]4.6807,[328]4.6914,[329]4.6975,
save_imatrix: stored collected data after 330 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[330]4.7043,[331]4.7033,[332]4.7074,[333]4.7257,[334]4.7232,[335]4.7341,[336]4.7505,[337]4.7598,[338]4.7639,[339]4.7518,
save_imatrix: stored collected data after 340 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[340]4.7632,[341]4.7790,[342]4.7941,[343]4.8108,[344]4.8320,[345]4.8604,[346]4.8630,[347]4.8644,[348]4.8671,[349]4.8744,
save_imatrix: stored collected data after 350 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[350]4.8872,[351]4.9059,[352]4.9061,[353]4.9034,[354]4.9135,[355]4.9099,[356]4.9097,[357]4.9080,[358]4.9038,[359]4.9090,
save_imatrix: stored collected data after 360 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[360]4.9194,[361]4.9156,[362]4.9129,[363]4.8963,[364]4.8796,[365]4.8635,[366]4.8503,[367]4.8318,[368]4.8157,[369]4.7996,
save_imatrix: stored collected data after 370 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[370]4.7866,[371]4.7714,[372]4.7556,[373]4.7429,[374]4.7300,[375]4.7123,[376]4.6990,[377]4.6849,[378]4.6693,[379]4.6561,
save_imatrix: stored collected data after 380 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[380]4.6537,[381]4.6400,[382]4.6338,[383]4.6367,[384]4.6232,[385]4.6175,[386]4.6067,[387]4.5890,[388]4.5728,[389]4.5652,
save_imatrix: stored collected data after 390 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[390]4.5555,[391]4.5414,[392]4.5239,[393]4.5066,[394]4.5040,[395]4.5026,[396]4.4983,[397]4.4885,[398]4.4904,[399]4.4896,
save_imatrix: stored collected data after 400 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[400]4.4737,[401]4.4599,[402]4.4532,[403]4.4396,[404]4.4289,[405]4.4190,[406]4.4097,[407]4.3939,[408]4.3788,[409]4.3650,
save_imatrix: stored collected data after 410 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[410]4.3531,[411]4.3425,[412]4.3370,[413]4.3287,[414]4.3245,[415]4.3204,[416]4.3186,[417]4.3135,[418]4.3088,[419]4.2949,
save_imatrix: stored collected data after 420 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[420]4.2807,[421]4.2662,[422]4.2536,[423]4.2399,[424]4.2282,[425]4.2152,[426]4.2014,[427]4.1906,[428]4.1766,[429]4.1644,
save_imatrix: stored collected data after 430 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[430]4.1522,[431]4.1426,[432]4.1317,[433]4.1220,[434]4.1199,[435]4.1181,[436]4.1119,[437]4.1020,[438]4.0960,[439]4.0828,
save_imatrix: stored collected data after 440 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[440]4.0702,[441]4.0585,[442]4.0473,[443]4.0362,[444]4.0323,[445]4.0235,[446]4.0202,[447]4.0146,[448]4.0049,[449]4.0016,
save_imatrix: stored collected data after 450 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[450]3.9946,[451]3.9874,[452]3.9767,[453]3.9699,[454]3.9632,[455]3.9550,[456]3.9433,[457]3.9322,[458]3.9206,[459]3.9096,
save_imatrix: stored collected data after 460 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[460]3.8989,[461]3.8899,[462]3.8810,[463]3.8748,[464]3.8676,[465]3.8636,[466]3.8585,[467]3.8536,[468]3.8486,[469]3.8434,
save_imatrix: stored collected data after 470 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[470]3.8384,[471]3.8334,[472]3.8284,[473]3.8242,[474]3.8191,[475]3.8139,[476]3.8096,[477]3.8047,[478]3.7999,[479]3.7961,
save_imatrix: stored collected data after 480 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[480]3.7857,[481]3.7764,[482]3.7724,[483]3.7656,[484]3.7584,[485]3.7486,[486]3.7395,[487]3.7308,[488]3.7220,[489]3.7158,
save_imatrix: stored collected data after 490 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[490]3.7084,[491]3.7021,[492]3.6979,[493]3.6921,[494]3.6859,[495]3.6784,[496]3.6778,[497]3.6749,[498]3.6698,[499]3.6680,
save_imatrix: stored collected data after 500 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[500]3.6647,[501]3.6636,[502]3.6639,[503]3.6667,[504]3.6647,[505]3.6596,[506]3.6523,[507]3.6562,[508]3.6667,[509]3.6749,
save_imatrix: stored collected data after 510 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[510]3.6826,[511]3.6895,[512]3.6970,[513]3.7014,[514]3.7053,[515]3.7070,[516]3.7142,[517]3.7171,[518]3.7236,[519]3.7323,
save_imatrix: stored collected data after 520 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[520]3.7451,[521]3.7605,[522]3.7736,[523]3.7724,[524]3.7795,[525]3.7835,[526]3.7893,[527]3.7909,[528]3.7931,[529]3.8026,
save_imatrix: stored collected data after 530 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[530]3.8079,[531]3.8093,[532]3.8157,[533]3.8216,[534]3.8286,[535]3.8288,[536]3.8280,[537]3.8288,[538]3.8326,[539]3.8370,
save_imatrix: stored collected data after 540 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[540]3.8417,[541]3.8458,[542]3.8479,[543]3.8506,[544]3.8552,[545]3.8600,[546]3.8687,[547]3.8770,[548]3.8832,[549]3.8919,
save_imatrix: stored collected data after 550 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[550]3.8987,[551]3.9068,[552]3.9131,[553]3.9191,[554]3.9255,[555]3.9314,[556]3.9288,[557]3.9261,[558]3.9228,[559]3.9276,
save_imatrix: stored collected data after 560 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[560]3.9331,[561]3.9378,[562]3.9433,[563]3.9446,[564]3.9491,[565]3.9497,[566]3.9542,[567]3.9548,[568]3.9547,[569]3.9539,
save_imatrix: stored collected data after 570 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[570]3.9546,[571]3.9576,[572]3.9542,[573]3.9509,[574]3.9462,[575]3.9421,[576]3.9350,[577]3.9290,[578]3.9228,[579]3.9166,
save_imatrix: stored collected data after 580 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[580]3.9139,[581]3.9150,[582]3.9130,[583]3.9138,[584]3.9120,[585]3.9120,[586]3.9116,[587]3.9092,[588]3.9037,[589]3.9042,
save_imatrix: stored collected data after 590 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[590]3.9009,[591]3.8938,[592]3.8872,[593]3.8797,[594]3.8743,[595]3.8708,[596]3.8695,[597]3.8677,[598]3.8667,[599]3.8645,
save_imatrix: stored collected data after 600 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[600]3.8599,[601]3.8535,[602]3.8530,[603]3.8530,[604]3.8531,[605]3.8489,[606]3.8471,[607]3.8441,[608]3.8471,[609]3.8458,
save_imatrix: stored collected data after 610 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[610]3.8437,[611]3.8437,[612]3.8425,[613]3.8375,[614]3.8309,[615]3.8233,[616]3.8164,[617]3.8088,[618]3.8017,[619]3.7944,
save_imatrix: stored collected data after 620 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[620]3.7875,[621]3.7793,[622]3.7714,[623]3.7645,[624]3.7581,[625]3.7505,[626]3.7442,[627]3.7367,[628]3.7299,[629]3.7234,
save_imatrix: stored collected data after 630 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[630]3.7168,[631]3.7100,[632]3.7058,[633]3.6985,[634]3.6944,[635]3.6929,[636]3.6877,[637]3.6811,[638]3.6755,[639]3.6690,
save_imatrix: stored collected data after 640 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[640]3.6621,[641]3.6558,[642]3.6496,[643]3.6435,[644]3.6372,[645]3.6309,[646]3.6249,[647]3.6194,[648]3.6184,[649]3.6123,
save_imatrix: stored collected data after 650 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[650]3.6055,[651]3.5992,[652]3.5929,[653]3.5864,[654]3.5796,[655]3.5730,[656]3.5668,[657]3.5611,[658]3.5547,[659]3.5570,
save_imatrix: stored collected data after 660 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[660]3.5573,[661]3.5604,[662]3.5587,[663]3.5528,[664]3.5491,[665]3.5438,[666]3.5377,[667]3.5324,[668]3.5275,[669]3.5226,
save_imatrix: stored collected data after 670 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[670]3.5177,[671]3.5125,[672]3.5068,[673]3.5011,[674]3.4970,[675]3.4922,[676]3.4869,[677]3.4818,[678]3.4765,[679]3.4708,
save_imatrix: stored collected data after 680 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[680]3.4667,[681]3.4611,[682]3.4562,[683]3.4514,[684]3.4459,[685]3.4411,[686]3.4390,[687]3.4378,[688]3.4343,[689]3.4294,
save_imatrix: stored collected data after 690 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[690]3.4236,[691]3.4176,[692]3.4125,[693]3.4072,[694]3.4033,[695]3.4014,[696]3.3996,[697]3.3970,[698]3.3955,[699]3.3934,
save_imatrix: stored collected data after 700 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[700]3.3914,[701]3.3899,[702]3.3879,[703]3.3860,[704]3.3840,[705]3.3822,[706]3.3807,[707]3.3782,[708]3.3768,[709]3.3745,
save_imatrix: stored collected data after 710 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[710]3.3728,[711]3.3708,[712]3.3714,[713]3.3711,[714]3.3711,[715]3.3720,[716]3.3730,[717]3.3740,[718]3.3748,[719]3.3762,
save_imatrix: stored collected data after 720 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[720]3.3783,[721]3.3785,[722]3.3791,[723]3.3798,[724]3.3814,[725]3.3823,[726]3.3842,[727]3.3854,[728]3.3873,[729]3.3872,
save_imatrix: stored collected data after 730 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[730]3.3873,[731]3.3885,[732]3.3909,[733]3.3917,[734]3.3921,[735]3.3923,[736]3.3934,[737]3.3951,[738]3.3954,[739]3.3978,
save_imatrix: stored collected data after 740 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[740]3.3994,[741]3.4013,[742]3.4027,[743]3.4031,[744]3.4030,[745]3.4037,[746]3.4056,[747]3.4068,[748]3.4084,[749]3.4093,
save_imatrix: stored collected data after 750 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[750]3.4106,[751]3.4113,[752]3.4133,[753]3.4155,[754]3.4160,[755]3.4166,[756]3.4181,[757]3.4199,[758]3.4207,[759]3.4219,
save_imatrix: stored collected data after 760 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[760]3.4226,[761]3.4234,[762]3.4251,[763]3.4255,[764]3.4273,[765]3.4286,[766]3.4300,[767]3.4306,[768]3.4316,[769]3.4319,
save_imatrix: stored collected data after 770 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[770]3.4328,[771]3.4347,[772]3.4354,[773]3.4356,[774]3.4361,[775]3.4380,[776]3.4390,[777]3.4410,[778]3.4409,[779]3.4423,
save_imatrix: stored collected data after 780 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[780]3.4438,[781]3.4454,[782]3.4471,[783]3.4492,[784]3.4495,[785]3.4502,[786]3.4508,[787]3.4524,[788]3.4528,[789]3.4550,
save_imatrix: stored collected data after 790 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[790]3.4564,[791]3.4576,[792]3.4577,[793]3.4588,[794]3.4610,[795]3.4623,[796]3.4624,[797]3.4635,[798]3.4646,[799]3.4681,
save_imatrix: stored collected data after 800 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[800]3.4686,[801]3.4682,[802]3.4700,[803]3.4718,[804]3.4725,[805]3.4733,[806]3.4738,[807]3.4747,[808]3.4753,[809]3.4761,
save_imatrix: stored collected data after 810 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat
[810]3.4778,[811]3.4796,[812]3.4808,[813]3.4815,[814]3.4820,
save_imatrix: stored collected data after 814 chunks in /mnt/data/models/Thireus/GLM-4.5-THIREUS-BF16-SPECIAL_SPLIT/imatrix-GLM-4.5-BF16.dat

Final estimate: PPL = 3.4820 +/- 0.01710

llama_print_timings:        load time =  252304.21 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time = 7778409.55 ms / 416768 tokens (   18.66 ms per token,    53.58 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time = 8047122.96 ms / 416769 tokens

======================== sorted layer importances
  0: Layer   0, <cos_sim> = 0.427685
  1: Layer   2, <cos_sim> = 0.754953
  2: Layer   1, <cos_sim> = 0.761973
  3: Layer   3, <cos_sim> = 0.859269
  4: Layer  91, <cos_sim> = 0.884954
  5: Layer   4, <cos_sim> = 0.903777
  6: Layer  32, <cos_sim> = 0.909284
  7: Layer   6, <cos_sim> = 0.912099
  8: Layer  39, <cos_sim> = 0.916231
  9: Layer  37, <cos_sim> = 0.916806
 10: Layer  23, <cos_sim> = 0.917485
 11: Layer  31, <cos_sim> = 0.917567
 12: Layer  41, <cos_sim> = 0.919794
 13: Layer  40, <cos_sim> = 0.921433
 14: Layer  33, <cos_sim> = 0.921658
 15: Layer  29, <cos_sim> = 0.921735
 16: Layer  30, <cos_sim> = 0.921975
 17: Layer  22, <cos_sim> = 0.923691
 18: Layer  14, <cos_sim> = 0.923877
 19: Layer  28, <cos_sim> = 0.923969
 20: Layer  38, <cos_sim> = 0.923987
 21: Layer  24, <cos_sim> = 0.924143
 22: Layer  34, <cos_sim> = 0.925237
 23: Layer  26, <cos_sim> = 0.925394
 24: Layer   7, <cos_sim> = 0.925584
 25: Layer  36, <cos_sim> = 0.927248
 26: Layer  13, <cos_sim> = 0.9273
 27: Layer  25, <cos_sim> = 0.928089
 28: Layer  21, <cos_sim> = 0.92832
 29: Layer  85, <cos_sim> = 0.928453
 30: Layer  84, <cos_sim> = 0.929674
 31: Layer  27, <cos_sim> = 0.930059
 32: Layer  35, <cos_sim> = 0.930223
 33: Layer  10, <cos_sim> = 0.930323
 34: Layer   8, <cos_sim> = 0.931374
 35: Layer   9, <cos_sim> = 0.931623
 36: Layer  11, <cos_sim> = 0.932746
 37: Layer  92, <cos_sim> = 0.93573
 38: Layer  42, <cos_sim> = 0.93659
 39: Layer  12, <cos_sim> = 0.938836
 40: Layer   5, <cos_sim> = 0.941925
 41: Layer  43, <cos_sim> = 0.944567
 42: Layer  86, <cos_sim> = 0.944959
 43: Layer  15, <cos_sim> = 0.947795
 44: Layer  20, <cos_sim> = 0.94999
 45: Layer  18, <cos_sim> = 0.952245
 46: Layer  83, <cos_sim> = 0.953255
 47: Layer  19, <cos_sim> = 0.953365
 48: Layer  44, <cos_sim> = 0.953818
 49: Layer  45, <cos_sim> = 0.953992
 50: Layer  17, <cos_sim> = 0.95693
 51: Layer  16, <cos_sim> = 0.957523
 52: Layer  46, <cos_sim> = 0.958581
 53: Layer  87, <cos_sim> = 0.958638
 54: Layer  80, <cos_sim> = 0.958765
 55: Layer  90, <cos_sim> = 0.959166
 56: Layer  81, <cos_sim> = 0.959806
 57: Layer  82, <cos_sim> = 0.961588
 58: Layer  47, <cos_sim> = 0.961593
 59: Layer  89, <cos_sim> = 0.961599
 60: Layer  88, <cos_sim> = 0.962237
 61: Layer  79, <cos_sim> = 0.9637
 62: Layer  48, <cos_sim> = 0.964249
 63: Layer  50, <cos_sim> = 0.965041
 64: Layer  49, <cos_sim> = 0.965927
 65: Layer  51, <cos_sim> = 0.966417
 66: Layer  54, <cos_sim> = 0.968401
 67: Layer  52, <cos_sim> = 0.968414
 68: Layer  76, <cos_sim> = 0.970006
 69: Layer  53, <cos_sim> = 0.970636
 70: Layer  55, <cos_sim> = 0.971746
 71: Layer  78, <cos_sim> = 0.971977
 72: Layer  75, <cos_sim> = 0.972703
 73: Layer  77, <cos_sim> = 0.976101
 74: Layer  58, <cos_sim> = 0.978261
 75: Layer  56, <cos_sim> = 0.978325
 76: Layer  57, <cos_sim> = 0.979194
 77: Layer  59, <cos_sim> = 0.979934
 78: Layer  73, <cos_sim> = 0.980678
 79: Layer  67, <cos_sim> = 0.981173
 80: Layer  66, <cos_sim> = 0.981491
 81: Layer  72, <cos_sim> = 0.981982
 82: Layer  68, <cos_sim> = 0.982146
 83: Layer  65, <cos_sim> = 0.982195
 84: Layer  61, <cos_sim> = 0.982205
 85: Layer  74, <cos_sim> = 0.982215
 86: Layer  60, <cos_sim> = 0.982433
 87: Layer  71, <cos_sim> = 0.982878
 88: Layer  63, <cos_sim> = 0.983266
 89: Layer  70, <cos_sim> = 0.98376
 90: Layer  64, <cos_sim> = 0.984047
 91: Layer  69, <cos_sim> = 0.984223
 92: Layer  62, <cos_sim> = 0.984883

======================== sorted attention importances
  0: Layer   0, <cos_sim> = 0.344938
  1: Layer   1, <cos_sim> = 0.591517
  2: Layer   2, <cos_sim> = 0.654333
  3: Layer   3, <cos_sim> = 0.823345
  4: Layer   7, <cos_sim> = 0.838734
  5: Layer  13, <cos_sim> = 0.848823
  6: Layer   4, <cos_sim> = 0.853067
  7: Layer   6, <cos_sim> = 0.85565
  8: Layer   9, <cos_sim> = 0.859293
  9: Layer   8, <cos_sim> = 0.865704
 10: Layer  15, <cos_sim> = 0.874006
 11: Layer  12, <cos_sim> = 0.876242
 12: Layer   5, <cos_sim> = 0.879342
 13: Layer  10, <cos_sim> = 0.882239
 14: Layer  11, <cos_sim> = 0.882394
 15: Layer  17, <cos_sim> = 0.892329
 16: Layer  16, <cos_sim> = 0.893135
 17: Layer  21, <cos_sim> = 0.897359
 18: Layer  14, <cos_sim> = 0.900934
 19: Layer  19, <cos_sim> = 0.90725
 20: Layer  20, <cos_sim> = 0.913042
 21: Layer  18, <cos_sim> = 0.915475
 22: Layer  23, <cos_sim> = 0.921054
 23: Layer  22, <cos_sim> = 0.928428
 24: Layer  24, <cos_sim> = 0.932289
 25: Layer  25, <cos_sim> = 0.937948
 26: Layer  32, <cos_sim> = 0.941675
 27: Layer  28, <cos_sim> = 0.942764
 28: Layer  27, <cos_sim> = 0.943415
 29: Layer  26, <cos_sim> = 0.944893
 30: Layer  33, <cos_sim> = 0.94565
 31: Layer  31, <cos_sim> = 0.947089
 32: Layer  30, <cos_sim> = 0.947855
 33: Layer  37, <cos_sim> = 0.948046
 34: Layer  38, <cos_sim> = 0.950758
 35: Layer  39, <cos_sim> = 0.951663
 36: Layer  35, <cos_sim> = 0.952977
 37: Layer  41, <cos_sim> = 0.953607
 38: Layer  29, <cos_sim> = 0.954196
 39: Layer  40, <cos_sim> = 0.954582
 40: Layer  34, <cos_sim> = 0.954729
 41: Layer  42, <cos_sim> = 0.954918
 42: Layer  36, <cos_sim> = 0.959325
 43: Layer  85, <cos_sim> = 0.962455
 44: Layer  43, <cos_sim> = 0.964714
 45: Layer  44, <cos_sim> = 0.965964
 46: Layer  45, <cos_sim> = 0.967585
 47: Layer  46, <cos_sim> = 0.970301
 48: Layer  86, <cos_sim> = 0.970694
 49: Layer  84, <cos_sim> = 0.971816
 50: Layer  51, <cos_sim> = 0.974334
 51: Layer  52, <cos_sim> = 0.974995
 52: Layer  83, <cos_sim> = 0.975456
 53: Layer  92, <cos_sim> = 0.977125
 54: Layer  50, <cos_sim> = 0.978028
 55: Layer  47, <cos_sim> = 0.978144
 56: Layer  81, <cos_sim> = 0.978666
 57: Layer  48, <cos_sim> = 0.978749
 58: Layer  49, <cos_sim> = 0.978877
 59: Layer  82, <cos_sim> = 0.978997
 60: Layer  53, <cos_sim> = 0.979392
 61: Layer  80, <cos_sim> = 0.979488
 62: Layer  91, <cos_sim> = 0.980023
 63: Layer  87, <cos_sim> = 0.981837
 64: Layer  58, <cos_sim> = 0.981972
 65: Layer  54, <cos_sim> = 0.982405
 66: Layer  79, <cos_sim> = 0.98302
 67: Layer  88, <cos_sim> = 0.983417
 68: Layer  78, <cos_sim> = 0.984292
 69: Layer  89, <cos_sim> = 0.986245
 70: Layer  90, <cos_sim> = 0.986536
 71: Layer  61, <cos_sim> = 0.986631
 72: Layer  68, <cos_sim> = 0.986747
 73: Layer  59, <cos_sim> = 0.986949
 74: Layer  71, <cos_sim> = 0.987292
 75: Layer  73, <cos_sim> = 0.987331
 76: Layer  55, <cos_sim> = 0.987493
 77: Layer  72, <cos_sim> = 0.987677
 78: Layer  76, <cos_sim> = 0.98837
 79: Layer  57, <cos_sim> = 0.988642
 80: Layer  56, <cos_sim> = 0.988706
 81: Layer  77, <cos_sim> = 0.988904
 82: Layer  67, <cos_sim> = 0.989027
 83: Layer  65, <cos_sim> = 0.989138
 84: Layer  70, <cos_sim> = 0.989192
 85: Layer  63, <cos_sim> = 0.989244
 86: Layer  74, <cos_sim> = 0.989437
 87: Layer  64, <cos_sim> = 0.989463
 88: Layer  66, <cos_sim> = 0.989895
 89: Layer  60, <cos_sim> = 0.989994
 90: Layer  69, <cos_sim> = 0.991278
 91: Layer  75, <cos_sim> = 0.991565
 92: Layer  62, <cos_sim> = 0.991979

======================== sorted ffn importances
  0: Layer   0, <cos_sim> = 0.624763
  1: Layer   1, <cos_sim> = 0.625883
  2: Layer   2, <cos_sim> = 0.758364
  3: Layer   6, <cos_sim> = 0.823684
  4: Layer   3, <cos_sim> = 0.82519
  5: Layer  14, <cos_sim> = 0.855503
  6: Layer   8, <cos_sim> = 0.857968
  7: Layer  11, <cos_sim> = 0.858688
  8: Layer   4, <cos_sim> = 0.860767
  9: Layer  12, <cos_sim> = 0.861034
 10: Layer   7, <cos_sim> = 0.866609
 11: Layer   5, <cos_sim> = 0.867507
 12: Layer  16, <cos_sim> = 0.880483
 13: Layer   9, <cos_sim> = 0.884367
 14: Layer  15, <cos_sim> = 0.884556
 15: Layer  10, <cos_sim> = 0.888465
 16: Layer  20, <cos_sim> = 0.889985
 17: Layer  91, <cos_sim> = 0.895348
 18: Layer  18, <cos_sim> = 0.897993
 19: Layer  19, <cos_sim> = 0.900205
 20: Layer  13, <cos_sim> = 0.900349
 21: Layer  24, <cos_sim> = 0.916468
 22: Layer  17, <cos_sim> = 0.916497
 23: Layer  22, <cos_sim> = 0.916572
 24: Layer  25, <cos_sim> = 0.919954
 25: Layer  26, <cos_sim> = 0.920522
 26: Layer  23, <cos_sim> = 0.921078
 27: Layer  27, <cos_sim> = 0.925478
 28: Layer  29, <cos_sim> = 0.930093
 29: Layer  21, <cos_sim> = 0.930119
 30: Layer  32, <cos_sim> = 0.933381
 31: Layer  31, <cos_sim> = 0.934207
 32: Layer  28, <cos_sim> = 0.935718
 33: Layer  30, <cos_sim> = 0.936237
 34: Layer  34, <cos_sim> = 0.938
 35: Layer  37, <cos_sim> = 0.940027
 36: Layer  36, <cos_sim> = 0.941134
 37: Layer  33, <cos_sim> = 0.943274
 38: Layer  39, <cos_sim> = 0.943999
 39: Layer  40, <cos_sim> = 0.94547
 40: Layer  35, <cos_sim> = 0.945489
 41: Layer  41, <cos_sim> = 0.94562
 42: Layer  38, <cos_sim> = 0.948466
 43: Layer  43, <cos_sim> = 0.952106
 44: Layer  42, <cos_sim> = 0.953932
 45: Layer  44, <cos_sim> = 0.954758
 46: Layer  45, <cos_sim> = 0.955495
 47: Layer  84, <cos_sim> = 0.957791
 48: Layer  46, <cos_sim> = 0.962356
 49: Layer  47, <cos_sim> = 0.963516
 50: Layer  85, <cos_sim> = 0.964269
 51: Layer  50, <cos_sim> = 0.964335
 52: Layer  90, <cos_sim> = 0.964703
 53: Layer  48, <cos_sim> = 0.96498
 54: Layer  51, <cos_sim> = 0.965182
 55: Layer  49, <cos_sim> = 0.965669
 56: Layer  52, <cos_sim> = 0.968839
 57: Layer  86, <cos_sim> = 0.968873
 58: Layer  89, <cos_sim> = 0.968923
 59: Layer  79, <cos_sim> = 0.971219
 60: Layer  83, <cos_sim> = 0.971772
 61: Layer  80, <cos_sim> = 0.972194
 62: Layer  87, <cos_sim> = 0.972445
 63: Layer  53, <cos_sim> = 0.972528
 64: Layer  81, <cos_sim> = 0.972955
 65: Layer  78, <cos_sim> = 0.973467
 66: Layer  57, <cos_sim> = 0.974059
 67: Layer  77, <cos_sim> = 0.974158
 68: Layer  82, <cos_sim> = 0.974329
 69: Layer  55, <cos_sim> = 0.974836
 70: Layer  88, <cos_sim> = 0.974857
 71: Layer  92, <cos_sim> = 0.975135
 72: Layer  54, <cos_sim> = 0.975202
 73: Layer  75, <cos_sim> = 0.975318
 74: Layer  76, <cos_sim> = 0.975601
 75: Layer  56, <cos_sim> = 0.976377
 76: Layer  58, <cos_sim> = 0.976683
 77: Layer  60, <cos_sim> = 0.977063
 78: Layer  73, <cos_sim> = 0.977821
 79: Layer  67, <cos_sim> = 0.97792
 80: Layer  72, <cos_sim> = 0.978052
 81: Layer  70, <cos_sim> = 0.978202
 82: Layer  59, <cos_sim> = 0.978237
 83: Layer  65, <cos_sim> = 0.978362
 84: Layer  71, <cos_sim> = 0.978628
 85: Layer  69, <cos_sim> = 0.978685
 86: Layer  66, <cos_sim> = 0.978833
 87: Layer  63, <cos_sim> = 0.979037
 88: Layer  64, <cos_sim> = 0.979078
 89: Layer  62, <cos_sim> = 0.979626
 90: Layer  74, <cos_sim> = 0.979673
 91: Layer  68, <cos_sim> = 0.979996
 92: Layer  61, <cos_sim> = 0.980596

@Thireus
Copy link
Contributor Author

Thireus commented Aug 1, 2025

Thank you so much @ubergarm!

@InfernalDread
Copy link

InfernalDread commented Aug 1, 2025

I apologize if this is off topic, but is it possible for a GLM-4.5-Air 4 bit GGUF quant to test as well, as it is the much friendlier option for resource limited individuals (such as myself).

Thank you all for your amazing and quick work!

@ubergarm
Copy link
Contributor

ubergarm commented Aug 1, 2025

I apologize if this is off topic, but is it possible for a GLM-4.5-Air 4 bit GGUF quant to test as well, as it is the much friendlier option for resource limited individuals (such as myself).

Having an issue converting the original bf16 safetensors using this PR to get my bf16 GGUF.
I changed two lines which gets me further:

@@ -3951,8 +3951,8 @@ class Dots1Model(Qwen2MoeModel):
             return [(self.map_tensor_name(name), data_torch)]
         return super().modify_tensors(data_torch, name, bid)

-@ModelBase.register("Glm4MoeForCausalLM")
-class Glm4MoeModel(TextModel):
+@Model.register("Glm4MoeForCausalLM")
+class Glm4MoeModel(Model):
     model_arch = gguf.MODEL_ARCH.GLM4_MOE

     def __init__(self, *args, **kwargs):
@@ -3960,7 +3960,7 @@ class Glm4MoeModel(TextModel):
         # GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer)
         self.block_count = self.hparams["num_hidden_layers"] + 1
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

But then it bombs out with this:

python \
    convert_hf_to_gguf.py \
    --outtype bf16 \
    --split-max-size 50G \
    --outfile /mnt/raid/models/ubergarm/GLM-4.5-Air-GGUF/ \
    /mnt/raid/models/zai-org/GLM-4.5-Air/

INFO:hf-to-gguf:gguf: experts used count = 8
INFO:hf-to-gguf:gguf: file type = 24
INFO:hf-to-gguf:Set model tokenizer
WARNING:hf-to-gguf:

WARNING:hf-to-gguf:**************************************************************************************
WARNING:hf-to-gguf:** WARNING: The BPE pre-tokenizer was not recognized!
WARNING:hf-to-gguf:**          There are 2 possible reasons for this:
WARNING:hf-to-gguf:**          - the model has not been added to convert_hf_to_gguf_update.py yet
WARNING:hf-to-gguf:**          - the pre-tokenization config has changed upstream
WARNING:hf-to-gguf:**          Check your model files and convert_hf_to_gguf_update.py and update them accordingly.
WARNING:hf-to-gguf:** ref:     https://github.com/ggerganov/llama.cpp/pull/6920
WARNING:hf-to-gguf:**
WARNING:hf-to-gguf:** chkhsh:  a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2
WARNING:hf-to-gguf:**************************************************************************************
WARNING:hf-to-gguf:

Traceback (most recent call last):
  File "/home/w/projects/ik_llama.cpp/convert_hf_to_gguf.py", line 4578, in <module>
    main()
  File "/home/w/projects/ik_llama.cpp/convert_hf_to_gguf.py", line 4572, in main
    model_instance.write()
  File "/home/w/projects/ik_llama.cpp/convert_hf_to_gguf.py", line 431, in write
    self.prepare_metadata(vocab_only=False)
  File "/home/w/projects/ik_llama.cpp/convert_hf_to_gguf.py", line 417, in prepare_metadata
    self.set_vocab()
  File "/home/w/projects/ik_llama.cpp/convert_hf_to_gguf.py", line 3971, in set_vocab
    tokens, toktypes, tokpre = self.get_vocab_base()
                               ^^^^^^^^^^^^^^^^^^^^^
  File "/home/w/projects/ik_llama.cpp/convert_hf_to_gguf.py", line 512, in get_vocab_base
    tokpre = self.get_vocab_base_pre(tokenizer)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/w/projects/ik_llama.cpp/convert_hf_to_gguf.py", line 662, in get_vocab_base_pre
    raise NotImplementedError("BPE pre-tokenizer was not recognized - update get_vocab_base_pre()")
NotImplementedError: BPE pre-tokenizer was not recognized - update get_vocab_base_pre()

Probably just a few name changes to go still? My other option is to use mainline lcpp's convert python script then try to pick back up here. Thoughts?

Oh I'll check for a missing commit here: convert_hf_to_gguf_update.py

@Thireus
Copy link
Contributor Author

Thireus commented Aug 1, 2025

Probably just a few name changes to go still? My other option is to use mainline lcpp's convert python script then try to pick back up here. Thoughts?

Oh I'll check for a missing commit here: convert_hf_to_gguf_update.py

I recommend using the llama.cpp script for this. Which is what I did. So not too sure if I had ported all the right changes. There are no changes to the _update.py version and I don't know what this version of the script is used for.

@InfernalDread
Copy link

InfernalDread commented Aug 1, 2025

image

I think it has something to do with the convert_hf_to_gguf_update.py script as well, as I do not see any reference of GLM 4.5 in the original convert_hf_to_gguf.py script

@ubergarm
Copy link
Contributor

ubergarm commented Aug 1, 2025

I recommend using the llama.cpp script for this. Which is what I did. So not too sure if I had ported all the right changes. There are no changes to the _update.py version and I don't know what this version of the script is used for.

Got it, yes I am now using sammcj [email protected]:sammcj/llama.cpp.git glm-4-5@3d15c4a94 and converting Air currently. I have not added the convert logic for Hunyuan-A13B and previous THUDM dense model either, but only implemented the cpp code.

For this PR will likely want to either:

  1. drop the convert script as it is broken currently and focus on the cpp code then follow up later with another PR for it
  2. spend a little time now to fix it up and bring it along as it is already started

Just some thoughts. I'll keep going, and other good news this seems to be working now: GLM-4.5-IQ4_KSS.gguf 178.396 GiB (4.276 BPW)

@ubergarm
Copy link
Contributor

ubergarm commented Aug 1, 2025

@Thireus

Sorry for spamming you, hah.. I think this patch is enough to get Air to convert here with this script. Interestingly needed a checksum from an older GLM model hah... should Be able to save it as patch.diff and run git apply patch.diff and push the change if u like.

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index c2748822..b1ca9e1c 100644
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -618,6 +618,9 @@ class Model:
         if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
             # ref: https://huggingface.co/THUDM/glm-4-9b-chat
             res = "chatglm-bpe"
+        if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
+            # ref: https://huggingface.co/THUDM/glm-4-9b-hf
+            res = "glm4"
         if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902":
             # ref: https://huggingface.co/zai-org/GLM-4.5-Air, https://huggingface.co/zai-org/GLM-4.5
             res = "gpt-2"
@@ -3951,8 +3954,8 @@ class Dots1Model(Qwen2MoeModel):
             return [(self.map_tensor_name(name), data_torch)]
         return super().modify_tensors(data_torch, name, bid)

-@ModelBase.register("Glm4MoeForCausalLM")
-class Glm4MoeModel(TextModel):
+@Model.register("Glm4MoeForCausalLM")
+class Glm4MoeModel(Model):
     model_arch = gguf.MODEL_ARCH.GLM4_MOE

     def __init__(self, *args, **kwargs):

Keep in mind I believe some things have changed in that draft PR and there could be a few other things. I'm gonna keep going with the convert Air from mainline for now as it is about 70% done.

@InfernalDread
Copy link

not sure if that "res" needs to be changed from "gpt-2" to "glm4"

@ubergarm
Copy link
Contributor

ubergarm commented Aug 1, 2025

I used mainline convert on Air and quantized a Q8_0 but got this trying to start up ik_llama.cpp llama-server:

llama_model_load: error loading model: done_getting_tensors: wrong number of tensors; expected 803, got 757

Unfortunately I'm out of time to play with this today. I may try using this PRs convert and look closer at the tensors when I have some more time.

src/llama.cpp Outdated
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
}

// --- NextN / MTP tensors (preserved but unused), on the final layer ---
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is MTP on the roadmap for ik_llama?

Thireus added a commit to Thireus/ik_llama.cpp that referenced this pull request Aug 2, 2025
@ubergarm
Copy link
Contributor

ubergarm commented Aug 2, 2025

Okay, having some luck with GLM-4.5-Air now. Cooking imatrix off the bf16 and testing the Q8_0 on llama-server so far so good!

compute_imatrix: tokenizing the input ..
compute_imatrix: tokenization took 899.288 ms
compute_imatrix: computing over 814 chunks with batch_size 512
compute_imatrix: 7.09 seconds per pass - ETA 1 hours 36.25 minutes
[1]18.4611,[2]7.6556,[3]4.9280,[4]3.5268,[5]2.8022,[6]2.3688,[7]2.1118,[8]1.9437,[9]1.9381,
save_imatrix: entry '             blk.46.ffn_down_exps.weight' has partial data (96.88%) 4 out of 128 experts are missing data Storing **but be aware**
save_imatrix: entry '               blk.46.ffn_up_exps.weight' has partial data (96.88%) 4 out of 128 experts are missing data Storing **but be aware**
save_imatrix: entry '             blk.46.ffn_gate_exps.weight' has partial data (96.88%) 4 out of 128 experts are missing data Storing **but be aware**

save_imatrix: stored collected data after 10 chunks in /mnt/raid/models/ubergarm/GLM-4.5-Air-GGUF/imatrix-GLM-4.5-Air-BF16.dat
[10]1.8529,[11]1.9772,[12]2.1353,[13]2.2048,[14]2.2779,
llm_load_print_meta: model type       = 106B.A12B
llm_load_print_meta: model ftype      = Q8_0
llm_load_print_meta: model params     = 109.194 B
llm_load_print_meta: model size       = 108.119 GiB (8.505 BPW)
llm_load_print_meta: repeating layers = 106.891 GiB (8.505 BPW, 107.952 B parameters)
llm_load_print_meta: general.name     = GLM 4.5 Air

I was thinking about what @TheLegendOfKitty said and I was like "psure MTP is not on the roadmap" so I just removed those MTP tensors and seems to be happy now. The other clue was on startup before this next patch here it was complaining with: llama_model_load: error loading model: done_getting_tensors: wrong number of tensors; expected 803, got 757. So 803-757 is 46 which psure is the number of repeating layers. So it was as if there were one tensor per layer missing. Thinking on it some more I was like "isn't there some need to set unused tensors with some kind of no-op?" but I don't recall so I just tried to remove adding those MTP tensors because I don't even know what MTP means.

So anyway I used this PR plus this patch on convert_hf_to_gguf.py and managed to convert GLM-4.5-Air and run it successfully now too. There might be a better way to go about it like is there a way to mark those extra tensors as noop so they don't cause a fuss? Also I added a comment because the mainline PR is using a different name ".ffn_gate_inp.bias" here is ".exp_probs_b" on mainline, so mainline couldn't run the Q8_0 throwing llama_model_load: error loading model: missing tensor 'blk.3.exp_probs_b' so I couldn't test it over there.

👈 Patch
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index c2748822..5da48e24 100644
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -618,6 +618,9 @@ class Model:
         if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
             # ref: https://huggingface.co/THUDM/glm-4-9b-chat
             res = "chatglm-bpe"
+        if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
+            # ref: https://huggingface.co/THUDM/glm-4-9b-hf
+            res = "glm4"
         if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902":
             # ref: https://huggingface.co/zai-org/GLM-4.5-Air, https://huggingface.co/zai-org/GLM-4.5
             res = "gpt-2"
@@ -3951,8 +3954,8 @@ class Dots1Model(Qwen2MoeModel):
             return [(self.map_tensor_name(name), data_torch)]
         return super().modify_tensors(data_torch, name, bid)

-@ModelBase.register("Glm4MoeForCausalLM")
-class Glm4MoeModel(TextModel):
+@Model.register("Glm4MoeForCausalLM")
+class Glm4MoeModel(Model):
     model_arch = gguf.MODEL_ARCH.GLM4_MOE

     def __init__(self, *args, **kwargs):
@@ -4093,7 +4096,7 @@ class Glm4MoeModel(TextModel):
         # Handle expert gating input (routing gate)
         if ".mlp.gate.e_score_correction_bias" in name:
             new_name = name.replace("model.layers.", "blk.").replace(
-                ".mlp.gate.e_score_correction_bias", ".ffn_gate_inp.bias"
+                ".mlp.gate.e_score_correction_bias", ".ffn_gate_inp.bias" # *NOTE* this is ".exp_probs_b" in mainline PR
             )
             return [(new_name, data_torch)]
         elif ".mlp.gate.weight" in name:
@@ -4132,6 +4135,7 @@ class Glm4MoeModel(TextModel):
             return [(self.map_tensor_name(new_name), data_torch)]

         # Handle special NextN tensors - preserve for future MTP support
+        # What are these MTP tensors? If we preserve them do they need to be "noop" or whatever?
         if (
             ".embed_tokens." in name
             or ".shared_head." in name
@@ -4140,7 +4144,9 @@ class Glm4MoeModel(TextModel):
             or ".hnorm." in name
         ):
             new_name = name.replace("model.layers.", "blk.").replace("model.", "").replace(".weight", "")
-            return [(new_name, data_torch)]
+            logger.debug(f"Skipping MTP tensor: {new_name}")
+            return []
+            #return [(new_name, data_torch)]

         # GLM tensor mapping - handle directly without map_tensor_name
         if ".input_layernorm." in name:

If this imatrix cooks I'll upload that at least and see about a ~4BPW quant or not depending on if the tensors are nice sizes etc.

@Thireus
Copy link
Contributor Author

Thireus commented Aug 2, 2025

@ubergarm - Don't we want to preserve the MTP tensors for future MTP support? I see you've set return[]. Would the script work without this change?

@Thireus
Copy link
Contributor Author

Thireus commented Aug 2, 2025

@ubergarm - More info about MTP can be found here: ggml-org/llama.cpp#13236

Also discussed here: ggml-org/llama.cpp#14939 (comment)

@Thireus
Copy link
Contributor Author

Thireus commented Aug 10, 2025

Have you done comparisons with mainline using the llama-sweep-bench port lately? I was surprised to see mainline appears to be markedly faster at least in this hardware/model/quant configuration.

Oh wow, I definitely need to check this out! Thanks for the tip. I'll provide feedback after I do.

@usrlocalben - Kimi is a beast! Sadly quite slow and very big, but it's on my list of options too!

@saood06
Copy link
Collaborator

saood06 commented Aug 10, 2025

IQ2 😰

Which IQ2?

@usrlocalben
Copy link
Contributor

usrlocalben commented Aug 10, 2025

IQ2 😰

Which IQ2?

ubergarm IQ2_KL

added intel autoround Q2 above as well.

@Thireus K2 is much faster than R1. fewer active params and no-think. much worse vram rent though.

@Thireus
Copy link
Contributor Author

Thireus commented Aug 10, 2025

Have you done comparisons with mainline using the llama-sweep-bench port lately? I was surprised to see mainline appears to be markedly faster at least in this hardware/model/quant configuration.

I can confirm llama.cpp is more than 2x faster than ik_llama.cpp on the Blackwell card. :|

I might switch back to llama.cpp then. 😂

See details: https://huggingface.co/ubergarm/GLM-4.5-Air-GGUF/discussions/2#689916c972f9ea17ff768e4b

@saood06
Copy link
Collaborator

saood06 commented Aug 10, 2025

IQ2 😰

Which IQ2?

ubergarm IQ2_KL

added intel autoround Q2 above as well.

Thanks for clarifying and adding the extra data point, that looks like an impressive result for the autoround compared to the ubergarm IQ2_KL

@Thireus K2 is much faster than R1. fewer active params and no-think. much worse vram rent though.

I'm not @Thireus but I haven't tried K2 even though it sounds tempting for that reason because my R1/V3 quants leave me no spare room at ~4.5 BPW, and so either I go to a lower BPW or RPC and neither sounded compelling enough so far.

@Thireus
Copy link
Contributor Author

Thireus commented Aug 12, 2025

Very interesting... GLM-4.5-Air was able to solve the Dipiloblop prompt with mainline llama.cpp using the following recipe:

## Quant mix recipe created using Thireus' GGUF Tool Suite - https://gguf.thireus.com/
# Model name: GLM-4.5-Air
# Link to the original model: https://huggingface.co/zai-org/GLM-4.5-Air

## Model head & embeddings — qbits: 32 8 5 
output_norm\.weight=f32
token_embd\.weight=q5_K
output\.weight=q8_0

## Multi-headed attention parameters — qbits: 32 4 
blk\.([0-9]|[1-3][0-9]|4[0-6])\.attn_k\.bias=f32
blk\.([0-9]|[1-3][0-9]|4[0-6])\.attn_output\.weight=iq4_xs
blk\.([0-9]|[1-3][0-9]|4[0-6])\.attn_k\.weight=iq4_xs
blk\.([0-9]|[1-3][0-9]|4[0-6])\.attn_q\.weight=iq4_xs
blk\.([0-9]|[1-3][0-9]|4[0-6])\.attn_q\.bias=f32
blk\.([0-9]|[1-3][0-9]|4[0-6])\.attn_v\.bias=f32
blk\.([0-9]|[1-3][0-9]|4[0-6])\.attn_v\.weight=iq4_xs
blk\.([0-9]|[1-3][0-9]|4[0-6])\.attn_norm\.weight=f32

## Core FFN weights — qbits: 32 8 
blk\.0\.ffn_gate\.weight=q8_0
blk\.([1-9]|[1-3][0-9]|4[0-6])\.ffn_gate_inp\.weight=f32
blk\.0\.ffn_down\.weight=q8_0
blk\.0\.ffn_up\.weight=q8_0

## Other tensors — qbits: 32 4 
blk\.([0-9]|[1-3][0-9]|4[0-6])\.post_attention_norm\.weight=f32
blk\.46\.nextn\.shared_head_head\.weight=iq4_xs
blk\.46\.nextn\.embed_tokens\.weight=iq4_xs
blk\.46\.nextn\.shared_head_norm\.weight=f32
blk\.([1-9]|[1-3][0-9]|4[0-6])\.exp_probs_b\.bias=f32
blk\.46\.nextn\.enorm\.weight=f32
blk\.46\.nextn\.hnorm\.weight=f32
blk\.46\.nextn\.eh_proj\.weight=iq4_xs

## GPU-loaded ffn_*_shexp
# ffn_down_shexp (down-projection) — qbits: 4 
blk\.([1-9]|[1-3][0-9]|4[0-6])\.ffn_down_shexp\.weight=iq4_xs

# ffn_up_shexp (up-projection) — qbits: 8 6 5 
blk\.(1|[3-4]|[7-8]|17|19|30|32|38|2[3-5]|4[0-5]|2[8-9]|2[0-1]|1[2-3]|3[4-5])\.ffn_up_shexp\.weight=q8_0
blk\.([5-6]|22|31|33|39|46|[2-3][6-7]|1[5-6]|1[0-1])\.ffn_up_shexp\.weight=q6_K
blk\.(2|9|14|18)\.ffn_up_shexp\.weight=q5_K

# ffn_gate_shexp (gate-projection) — qbits: 8 6 5 
blk\.([1-3]|[6-9]|19|20|30|33|4[0-5]|1[0-2]|2[3-9])\.ffn_gate_shexp\.weight=q8_0
blk\.(4|13|18|46|[2-3][1-2]|3[4-9]|1[5-6])\.ffn_gate_shexp\.weight=q6_K
blk\.(5|14|17)\.ffn_gate_shexp\.weight=q5_K

## CPU-loaded ffn_*_exps
# ffn_down_exps (down-extraction) — qbits: 4 
blk\.([1-9]|[1-3][0-9]|4[0-6])\.ffn_down_exps\.weight=iq4_nl

# ffn_up_exps (up-extraction) — qbits: 5 4 3 2 
blk\.46\.ffn_up_exps\.weight=q5_K
blk\.(25|37|39|3[1-2]|2[8-9]|3[4-5])\.ffn_up_exps\.weight=iq4_xs
blk\.(3|[5-6]|10|12|15|18|30|33|36|38|4[0-5]|2[6-7]|2[0-4])\.ffn_up_exps\.weight=iq3_s
blk\.([1-2]|4|[7-9]|11|19|1[6-7]|1[3-4])\.ffn_up_exps\.weight=q2_K

# ffn_gate_exps (gate-extraction) — qbits: 5 4 3 2 
blk\.46\.ffn_gate_exps\.weight=q5_K
blk\.(10|2[8-9]|3[1-4])\.ffn_gate_exps\.weight=iq4_xs
blk\.(3|6|30|4[0-5]|2[0-7]|3[5-9]|1[8-9]|1[3-6])\.ffn_gate_exps\.weight=iq3_s
blk\.([1-2]|[4-5]|[7-9]|17|1[1-2])\.ffn_gate_exps\.weight=q2_K

## Summary of tensor sizes per class
# GPU Total: 5.015 GiB (94.9%) | 5.29 GiB max, if all were q8_0 | 4.64 GiB min, if all were q5_K
# CPU Total: 44.902 GiB (87.0%) | 51.61 GiB max, if all were iq4_xs | 39.04 GiB min, if all were q2_K
# GPU+CPU Total: 49.917 GiB (90.9%)

## Summary of tensor counts and bpw per qtype
#
# GPU-loaded quants:
# QTYPE		Count	BPW	Assigned GiB	% Assigned	Max GiB (all)
# +f32       	331	32.0  	  0.09 GiB	-		-
# +q8_0      	1  	8.5   	  0.04 GiB	-		-
# q8_0      	57 	8.5   	  1.01 GiB	54.9%		1.84
# q6_K      	31 	6     	  0.14 GiB	9.6%		1.42
# q5_K      	8  	5     	  0.42 GiB	35.5%		1.19
# +iq4_xs    	237	4.25  	  3.31 GiB	-		-
#
# CPU-loaded quants:
# QTYPE		Count	BPW	Assigned GiB	% Assigned	Max GiB (all)
# +q5_K      	2  	5     	  0.95 GiB	-		-
# +iq4_nl    	46 	4.5   	 17.79 GiB	-		-
# iq4_xs    	16 	4.25  	  5.84 GiB	17.8%		32.87
# iq3_s     	52 	3.4375	 15.36 GiB	57.8%		26.59
# q2_K      	22 	2     	  4.96 GiB	24.4%		20.30
#
# -Average BPW: 3.7031
#
# -Notes:
# - '+' means user-defined pre-assigned tensors, or tensor missing from csv data or f32 tensors
# - Recipe produced on the 2025-08-12 05:39:33 UTC+0000 using Thireus' GGUF tools (https://gguf.thireus.com/)
# - Script SHA-256: a02563df96ccec6c78ab7c716771153ae5f5ef4e9ee6a04d372f273eb1662e9c
# - Calibration dataset 'ppl_results.csv' SHA-256: c596235f01c582988d23f97e1e6809a83923ae3f5321e3cde00625c9c92952f3
# - tensors.bf16.map SHA-256: f440313db9b7ce593240c0b0acb723182ee3ae9570eca868dc6eb440112fdd67
# - tensors.bf16.map model name: GLM-4.5-Air-THIREUS-BF16-SPECIAL_TENSOR-00804-of-00804
# - tensors.iq4_xs.map SHA-256: 28c799175c45409d6f59d609e82c5f0ed2bba3240b7c5697afbdc76824b1b046
# - tensors.iq4_xs.map model name: GLM-4.5-Air-THIREUS-IQ4_XS-SPECIAL_TENSOR-00804-of-00804
# - tensors.iq3_s.map SHA-256: c728289eeab5a07292bbaec6db8949ba9ddcf63e6cc337b0f15565755769e115
# - tensors.iq3_s.map model name: GLM-4.5-Air-THIREUS-IQ3_S-SPECIAL_TENSOR-00804-of-00804
# - tensors.q2_K.map SHA-256: 35835d7d56a8161d07c31c73f42008cba7c7e2025d4bb60be2a82b254497f183
# - tensors.q2_K.map model name: GLM-4.5-Air-THIREUS-Q2_K-SPECIAL_TENSOR-00804-of-00804
# - tensors.q8_0.map SHA-256: c00093e70a6c32aab72b404457c12a7b238b0e030975267d93d2b09a30796151
# - tensors.q8_0.map model name: GLM-4.5-Air-THIREUS-Q8_0-SPECIAL_TENSOR-00804-of-00804
# - tensors.q5_K.map SHA-256: b60aadca788055846a572cad5121e1d93bfa9bbbd520ae6350c84f52319f945f
# - tensors.q5_K.map model name: GLM-4.5-Air-THIREUS-Q5_K-SPECIAL_TENSOR-00804-of-00804
# - tensors.q6_K.map SHA-256: 5165939ae192b9008b49432f574da6df0a8df9989faf337cc3a062d04f80aef2
# - tensors.q6_K.map model name: GLM-4.5-Air-THIREUS-Q6_K-SPECIAL_TENSOR-00804-of-00804
# - tensors.iq2_ks.map SHA-256: efed8f3d7d712a6ad99c5904f6e2f4b89387cc78e4008d9ca557bd04da1f2b31
# - tensors.iq2_ks.map model name: GLM-4.5-Air-THIREUS-IQ2_KS-SPECIAL_TENSOR-00804-of-00804
# - tensors.iq4_nl.map SHA-256: ed9419ab49fc319d033148d4241f28fcafb1d3089cb3172da9fa5f62c978a93d
# - tensors.iq4_nl.map model name: GLM-4.5-Air-THIREUS-IQ4_NL-SPECIAL_TENSOR-00804-of-00804
# - GPG signatures: PASSED
# - Command used:
# ../../quant_assign.py ppl_results.csv --tolerance 0.01 --cpu-irq-k 1.5 --gpu-irq-k 1.5 --gpu-assign-qtype iq4_xs \
# --cpu-tensors-max-size 45 --gpu-tensors-max-size 95% --exponential-factor 8 --cpu-tensors \
# 'blk\.([1-9]|[1-3][0-9]|4[0-5])\.ffn_up_exps\.weight' 'blk\.([1-9]|[1-3][0-9]|4[0-5])\.ffn_gate_exps\.weight' \
# --gpu-tensors '.*' --cpu-quants iq4_xs iq3_s q2_K --gpu-quants q8_0 q5_K q6_K --cpu-assign-tensors \
# 'blk\.(46)\.ffn_up_exps\.weight=q5_K' 'blk\.(46)\.ffn_gate_exps\.weight=q5_K' \
# 'blk\.([1-9]|[1-3][0-9]|4[0-6])\.ffn_down_exps\.weight=iq4_nl' --gpu-assign-tensors \
# 'blk\.(0)\.ffn_down\.weight=q8_0'

## THE END!

Model answer:

<think>Let's analyze the question: "In dipiloblop, what level are you at when you get to half the xp of max level?"

From the provided book, I need to look for information about experience and leveling. I should find the experience table or any information about how experience relates to levels.

Looking through the text, I found a "Skills/Experience Table" section that shows experience needed for each level. Let me extract this information:

\`\`\`
Level 	Experience needed to reach level 	Experience difference from last level
1	0	0
2	266	266
3	553	287
4	888	335
5	1,264	376
6	1,697	433
7	2,202	505
8	2,761	559
9	3,421	660
10	4,162	741
11	5,009	847
12	5,987	978
13	7,090	1,103
14	8,339	1,249
15	9,792	1,453
16	11,423	1,631
17	13,273	1,850
18	15,397	2,124
19	17,798	2,401
20	20,537	2,739
21	23,635	3,098
22	27,183	3,548
23	31,160	3,977
24	35,720	4,560
25	40,895	5,175
26	46,750	5,855
27	53,376	6,626
28	60,943	7,567
29	69,493	8,550
30	79,195	9,702
31	90,198	11,003
32	102,691	12,493
33	116,838	14,147
34	132,892	16,054
35	151,085	18,193
36	171,731	20,646
37	195,093	23,362
38	221,634	26,541
39	251,703	30,069
40	285,768	34,065
41	324,379	38,611
42	368,189	43,810
43	417,795	49,606
44	474,049	56,254
45	537,848	63,799
46	610,073	72,225
47	691,990	81,917
48	784,837	92,847
49	890,059	105,222
50	1,009,301	119,242
51	1,144,480	135,179
52	1,297,644	153,164
53	1,471,284	173,640
54	1,667,999	196,715
55	1,891,034	223,035
56	2,143,788	252,754
57	2,430,212	286,424
58	2,754,802	324,590
59	3,122,772	367,970
60	3,539,721	416,949
61	4,012,280	472,559
62	4,547,846	535,566
63	5,154,809	606,963
64	5,842,739	687,930
65	6,622,340	779,601
66	7,505,874	883,534
67	8,507,246	1,001,372
68	9,642,117	1,134,871
69	10,928,297	1,286,180
70	12,385,868	1,457,571
71	14,037,935	1,652,067
72	15,910,153	1,872,218
73	18,031,971	2,121,818
74	20,436,698	2,404,727
75	23,162,089	2,725,391
76	26,250,771	3,088,682
77	29,751,292	3,500,521
78	33,718,506	3,967,214
79	38,214,716	4,496,210
80	43,310,285	5,095,569
81	49,085,275	5,774,990
82	55,630,259	6,544,984
83	63,047,786	7,417,527
84	71,454,370	8,406,584
85	80,981,751	9,527,381
86	91,779,442	10,797,691
87	104,016,620	12,237,178
88	117,885,480	13,868,860
89	133,603,420	15,717,940
90	151,417,175	17,813,755
91	171,605,798	20,188,623
92	194,486,167	22,880,369
93	220,417,391	25,931,224
94	249,805,802	29,388,411
95	283,112,599	33,306,797
96	320,860,372	37,747,773
97	363,640,890	42,780,518
98	412,125,351	48,484,461
99	467,074,482	54,949,131
100	507,038,931	48,364,282
101	556,387,432	49,348,501
\`\`\`

To answer the question, I need to determine what level corresponds to having half the experience needed to reach the maximum level.

From the table, I can see that:
- Level 100 requires 507,038,931 experience points
- Level 101 requires 556,387,432 experience points

The maximum level shown is 101, which requires 556,387,432 experience points.

Half of the experience needed to reach level 101 is 556,387,432 ÷ 2 = 278,193,716 experience points.

Looking at the table, I need to find which level requires closest to 278,193,716 experience points.

From the table:
- Level 75 requires 23,162,089 experience points
- Level 76 requires 26,250,771 experience points
- Level 77 requires 29,751,292 experience points
- Level 78 requires 33,718,506 experience points
- Level 79 requires 38,214,716 experience points
- Level 80 requires 43,310,285 experience points
- Level 81 requires 49,085,275 experience points
- Level 82 requires 55,630,259 experience points
- Level 83 requires 63,047,786 experience points
- Level 84 requires 71,454,370 experience points
- Level 85 requires 80,981,751 experience points
- Level 86 requires 91,779,442 experience points
- Level 87 requires 104,016,620 experience points
- Level 88 requires 117,885,480 experience points
- Level 89 requires 133,603,420 experience points
- Level 90 requires 151,417,175 experience points
- Level 91 requires 171,605,798 experience points
- Level 92 requires 194,486,167 experience points
- Level 93 requires 220,417,391 experience points
- Level 94 requires 249,805,802 experience points
- Level 95 requires 283,112,599 experience points

Half of the max experience (278,193,716) falls between level 94 (249,805,802) and level 95 (283,112,599).

To be more precise, I can calculate how far we are from level 94 to level 95:
- Level 94 to 95 difference: 283,112,599 - 249,805,802 = 33,306,797 experience points
- How much we need beyond level 94: 278,193,716 - 249,805,802 = 28,387,914 experience points
- Percentage of the way: 28,387,914 ÷ 33,306,797 ≈ 0.852 or 85.2%

So, you would be at approximately 85.2% of the way from level 94 to level 95 when you have half the experience needed to reach level 101.

However, the question asks what level you're at, not what percentage. Since you don't reach level 95 until you have 283,112,599 experience points, and 278,193,716 is less than that, you would still be level 94 when you have half the experience needed to reach the maximum level of 101.

Actually, let me double-check this. The question is asking what level you're at when you have half the experience needed to reach the maximum level. The maximum level shown in the table is 101, which requires 556,387,432 experience points.

Half of that is 278,193,716 experience points.

Looking at the table:
- Level 94 requires 249,805,802 experience points
- Level 95 requires 283,112,599 experience points

Since 278,193,716 is greater than 249,805,802 but less than 283,112,599, you would be level 94 when you have half the experience needed to reach level 101.

Wait, I think I need to be more precise. In most RPGs like this, you're considered a certain level once you've accumulated enough experience to reach that level, not when you're partway through it. So, when you have 278,193,716 experience points, you've accumulated enough to reach level 94, but not enough to reach level 95. Therefore, you would be considered level 94.

So the answer is that you're at level 94 when you have half the experience needed to reach the maximum level of 101.</think>Based on the experience table provided in the book, I can determine what level you're at when you have half the experience needed to reach the maximum level.

The maximum level shown in the table is level 101, which requires 556,387,432 experience points to reach.

Half of this experience requirement is:
556,387,432 ÷ 2 = 278,193,716 experience points

Looking at the experience table:
- Level 94 requires 249,805,802 experience points
- Level 95 requires 283,112,599 experience points

Since 278,193,716 is greater than the requirement for level 94 but less than the requirement for level 95, you would be at level 94 when you have half the experience needed to reach the maximum level of 101.
Prompt
- Tokens: 104293
- Time: 115899.915 ms
- Speed: 899.9 t/s
Generation
- Tokens: 2762
- Time: 139554.456 ms
- Speed: 19.8 t/s

@Thireus
Copy link
Contributor Author

Thireus commented Aug 12, 2025

I'm positive. Same model, different results: ik_llama.cpp is not able to solve it, while llama.cpp solves it. Something is up with our implementation.

@ikawrakow
Copy link
Owner

@Thireus

Unless you are using a zero-temperature setting, to make a definitive statement you need to run the task several times and count how many times the task was solved.

But given the performance results, I wouldn't be surprised if the implementation in ik_llama.cpp is not 100% correct. I'm in the middle of implementing GPT-OSS, so don't want to get distracted with other stuff, but when I'm done, I'll try to look more closely into GLM-4.5.

@Thireus
Copy link
Contributor Author

Thireus commented Aug 12, 2025

@ikawrakow, you're right - need a more scientific testing methodology. I have decided to set the temps and related settings the same between ik_llama.cpp and llama.cpp and do 10 rounds for each. Here are the results:

ik_llama.cpp:

CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0 ~/ik_llama-main-b4074-62ef02e-bin-win-cuda-12.8-x64-avx512/llama-server -m GLM-4.5-Air-THIREUS-BF16-SPECIAL_TENSOR-00001-of-00804.gguf -fa   -ctk f16   -ngl 99   -b 4096 -ub 4096   --threads 36   --main-gpu 0 --no-mmap --temp 0.6 --top-k 40 --top-p 0.95 --min-p 0.05
  1. WRONG, BUT CORRECT... Thinks max level is 99.
  2. WRONG. Numbers pulled from the XP table are mixed up.
  3. ALMOST. Considered level 95 instead.
  4. CORRECT!
  5. WRONG, BUT CORRECT... Numbers pulled from the XP table are mixed up, but answer is correct.
  6. WRONG, BUT CORRECT... Numbers pulled from the XP table are mixed up, but answer is correct.
  7. CORRECT!
  8. CORRECT!
  9. WRONG, completely wrong, model is hallucinating. Numbers pulled from the XP table are invented.
  10. WRONG. Numbers pulled from the XP table are very mixed up.

llama.cpp:

$ CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=1 ~/llama-b6146-bin-win-cuda-12.8-x64/llama-server -m GLM-4.5-Air-THIREUS-BF16-SPECIAL_TENSOR-00001-of-00804.gguf -fa   -ctk f16   -c 131072   -ngl 99   -b 4096 -ub 4096   --no-mmap   --threads 36   --main-gpu 0 --temp 0.6 --top-k 40 --top-p 0.95 --min-p 0.05 --port 8081
  1. WRONG. Numbers pulled from the XP table are mixed up.
  2. CORRECT!
  3. WRONG. Numbers pulled from the XP table are mixed up.
  4. WRONG, BUT CORRECT... Numbers pulled from the XP table are mixed up, but answer is correct.
  5. CORRECT!
  6. CORRECT!
  7. CORRECT!
  8. CORRECT!
  9. CORRECT!
  10. CORRECT!

So, it looks like ik_llama.cpp is somehow unable to pull the correct XP for level 101 most of the time.

The same model recipe was used for both:

## Model head & embeddings
output_norm\.weight=f32
token_embd\.weight=q5_K
output\.weight=q8_0

## Multi-headed attention parameters
blk\.([0-9]|[1-3][0-9]|4[0-6])\.attn_k\.bias=f32
blk\.([0-9]|[1-3][0-9]|4[0-6])\.attn_output\.weight=iq4_xs
blk\.([0-9]|[1-3][0-9]|4[0-6])\.attn_k\.weight=iq4_xs
blk\.([0-9]|[1-3][0-9]|4[0-6])\.attn_q\.weight=iq4_xs
blk\.([0-9]|[1-3][0-9]|4[0-6])\.attn_q\.bias=f32
blk\.([0-9]|[1-3][0-9]|4[0-6])\.attn_v\.bias=f32
blk\.([0-9]|[1-3][0-9]|4[0-6])\.attn_v\.weight=iq4_xs
blk\.([0-9]|[1-3][0-9]|4[0-6])\.attn_norm\.weight=f32

## Core FFN weights
blk\.0\.ffn_gate\.weight=iq3_xxs
blk\.([1-9]|[1-3][0-9]|4[0-6])\.ffn_gate_inp\.weight=f32
blk\.0\.ffn_down\.weight=iq3_xxs
blk\.0\.ffn_up\.weight=iq3_xxs

## Other tensors
blk\.([0-9]|[1-3][0-9]|4[0-6])\.post_attention_norm\.weight=f32
blk\.46\.nextn\.shared_head_head\.weight=iq4_xs
blk\.46\.nextn\.embed_tokens\.weight=iq4_xs
blk\.46\.nextn\.shared_head_norm\.weight=f32
blk\.([1-9]|[1-3][0-9]|4[0-6])\.exp_probs_b\.bias=f32
blk\.46\.nextn\.enorm\.weight=f32
blk\.46\.nextn\.hnorm\.weight=f32
blk\.46\.nextn\.eh_proj\.weight=iq4_xs

## GPU-loaded ffn_*_shexp
# ffn_down_shexp (down-projection) — qbits: 4 
blk\..*\.ffn_down_shexp\.weight=iq3_xxs

# ffn_up_shexp (up-projection)
blk\..*\.ffn_up_shexp\.weight=iq3_xxs

# ffn_gate_shexp (gate-projection)
blk\..*\.ffn_gate_shexp\.weight=iq3_xxs

## CPU-loaded ffn_*_exps
# ffn_down_exps (down-extraction)
blk\..*\.ffn_down_exps\.weight=iq3_xxs

# ffn_up_exps (up-extraction)
blk\..*\.ffn_up_exps\.weight=iq3_xxs

# ffn_gate_exps (gate-extraction)
blk\..*\.ffn_gate_exps\.weight=iq3_xxs

ik_vs_llama.zip

@trilog-inc
Copy link

One thing to note, for long context -b 4096 and -ub 4096 breaks the model. Same with mainline. Maybe you should use -b 2048 and -ub 2048

ggml-org/llama.cpp#14939 (comment)

@Thireus
Copy link
Contributor Author

Thireus commented Aug 12, 2025

One thing to note, for long context -b 4096 and -ub 4096 breaks the model. Same with mainline. Maybe you should use -b 2048 and -ub 2048

ggml-org/llama.cpp#14939 (comment)

Ah! I'll give it a go thanks. Interestingly I have not noticed any issue with llama.cpp.

Edit: Nope that's not it, same issue.
Edit2: omg, -n 512 was in the command line. Alright, giving it all another shot.
Edit3: Previous results updated. Outcome remains that something is up with the ik_llama.cpp implementation (or I'm doing something wrong, which is entirely possible).

@saood06
Copy link
Collaborator

saood06 commented Aug 13, 2025

ik_llama.cpp:

CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0 ~/ik_llama-main-b4074-62ef02e-bin-win-cuda-12.8-x64-avx512/llama-server -m GLM-4.5-Air-THIREUS-BF16-SPECIAL_TENSOR-00001-of-00804.gguf -fa   -ctk f16   -ngl 99   -b 4096 -ub 4096   --threads 36   --main-gpu 0 -p 8192 --no-mmap --temp 0.6 --top-k 40 --top-p 0.95 --min-p 0.05

[...]
llama.cpp:

$ CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=1 ~/llama-b6146-bin-win-cuda-12.8-x64/llama-server -m GLM-4.5-Air-THIREUS-BF16-SPECIAL_TENSOR-00001-of-00804.gguf -fa   -ctk f16   -c 131072   -ngl 99   -b 4096 -ub 4096   --no-mmap   --threads 36   --main-gpu 0 --temp 0.6 --top-k 40 --top-p 0.95 --min-p 0.05 --port 8081

Why is -p 8192 set for ik_llama? It is not set for llama.cpp

@Thireus
Copy link
Contributor Author

Thireus commented Aug 13, 2025

Why is -p 8192 set for ik_llama? It is not set for llama.cpp

Lack of sleep. I ran everything again and updated my results in this post: #668 (comment)

Same outcome as before, I believe this parameter didn't have any effect.

@saood06
Copy link
Collaborator

saood06 commented Aug 13, 2025

I believe this parameter didn't have any effect

I thought it set a global system_prompt which would prefix the context which could affect results.

@Thireus
Copy link
Contributor Author

Thireus commented Aug 13, 2025

I thought it set a global system_prompt which would prefix the context which could affect results.

You were right to point this out, I think this is what it does indeed and probably didn't make the previous results any better. The updated results are without this -p option.

Would be good if someone could replicate this test without necessarily using my parameters but their own to validate if something is up with our implementation. Could be RoPE, could be something else... could also be extreme good luck I had with 10 good seeds on llama.cpp... but I doubt it.

@ikawrakow
Copy link
Owner

WRONG, BUT CORRECT... Numbers pulled from the XP table are mixed up, but answer is correct.

So, if we count this as correct, then llama.cpp gets 9/10, ik_llama.cpp gets 6/10. Let's assume for a moment that there isn't a real difference between the two, and let's assume that the actual fraction of correct answers is 7.5/10 (average of llama.cpp and ik_llama.cpp). With these assumptions, a sample of 10 random draws has a statistical uncertainty of 7.5*sqrt(0.25*0.75/9) = 3.2, and hence the difference is not really statistically significant.

But if there is a real difference, the fact that llama.cpp mixes up the table more often may indicate an issue with tokenization and/or handling of special tokens.

@trilog-inc
Copy link

Another update from mainline, @ajunca narrowed down some of the large context issues to a CUDA graph issue. Could be interesting to run your test with a compiled version with GGML_CUDA_USE_GRAPHS=OFF

ggml-org/llama.cpp#14939 (comment)

@ikawrakow
Copy link
Owner

ikawrakow commented Aug 13, 2025

CUDA graphs are turned off for MoE models in ik_llama.cpp.

Oh, and the huge PPL observed for GPT-OSS is there with or without CUDA graphs. Here and in mainline.

@ajunca
Copy link

ajunca commented Aug 13, 2025

Another update from mainline, @ajunca narrowed down some of the large context issues to a CUDA graph issue. Could be interesting to run your test with a compiled version with GGML_CUDA_USE_GRAPHS=OFF

ggml-org/llama.cpp#14939 (comment)

No, at the end was not a problem of cuda graphs... After more testing, the problem was still there. I think though, is a problem of CUDA device synchronization. And that there are two problems related:

  • One is the gibberish problem (eg. GGGGGGGGG...), that at the end if you wait long enough (sometimes is very fast, some times it takes longer), you get a cuda crash like:
/app/ggml/src/ggml-cuda/ggml-cuda.cu:84: CUDA error
CUDA error: an illegal memory access was encountered
[0m  current device: 1, in function ggml_backend_cuda_synchronize at /app/ggml/src/ggml-cuda/ggml-cuda.cu:2612
[0m  cudaStreamSynchronize(cuda_ctx->stream())
[0mlibggml-base.so(+0x16d4b)[0x7fb29ab44d4b]
libggml-base.so(ggml_print_backtrace+0x21f)[0x7fb29ab451af]
libggml-base.so(ggml_abort+0x152)[0x7fb29ab45382]
/app/libggml-cuda.so(+0xea826)[0x7fb292aea826]
/app/libggml-cuda.so(+0xebaf3)[0x7fb292aebaf3]
libggml-base.so(ggml_backend_sched_synchronize+0x2e)[0x7fb29ab5aece]
libllama.so(_ZN13llama_context11synchronizeEv+0x14)[0x7fb29a897ac4]
libllama.so(llama_get_logits_ith+0x16)[0x7fb29a899de6]
/app/llama-server(+0x1f00fe)[0x55d97e53e0fe]
/app/llama-server(+0xdf547)[0x55d97e42d547]
/app/llama-server(+0x857bd)[0x55d97e3d37bd]
/app/llama-server(+0x4d6a5)[0x55d97e39b6a5]
/usr/lib/x86_64-linux-gnu/libc.so.6(+0x29d90)[0x7fb29a029d90]
/usr/lib/x86_64-linux-gnu/libc.so.6(__libc_start_main+0x80)[0x7fb29a029e40]
/app/llama-server(+0x4f0f5)[0x55d97e39d0f5]
  • And second issue is that some times, during prompt processing it gets kind of stuck/freeze, and does not advance more, lets say it stucks at 0.84. If you look at GPU usage, you will see one idle, and the other 100%. It also smells a bit as a synchronization issue. This is why I think they are connected. Example when pp gets stuck:
Screenshot_20250813_143915

If you go to #ggml-org/llama.cpp#15112 you will see things I tried.

Probably I will open a new issue, with all the stuff.

@ikawrakow
Copy link
Owner

@ajunca The problems you are describing are about mainline llama.cpp, not ik_llama.cpp, correct?

@trilog-inc
Copy link

Apologies if i brought in a different issue. I can report that setting -b 4096 and -ub 4096 affects ik_llama the same way as ajunca has described in mainline. Thought it could be good to reference their investigation if they are related, maybe they arent

@ajunca
Copy link

ajunca commented Aug 13, 2025

@ajunca The problems you are describing are about mainline llama.cpp, not ik_llama.cpp, correct?

Yes indead, are related to llama.cpp. I tried also ik_llama.cpp at some point, I remember I had problems, but don't remember which ones. Does ik_llama.cpp has a Vulkan backend? In llama.cpp, the Vulkan backend is workings (but much slower).

@ikawrakow
Copy link
Owner

In llama.cpp, the Vulkan backend is workings (but much slower).

The much slower part is no longer true.

@Thireus
Copy link
Contributor Author

Thireus commented Aug 13, 2025

One is the gibberish problem (eg. GGGGGGGGG...), that at the end if you wait long enough (sometimes is very fast, some times it takes longer), you get a cuda crash like:

I can confirm the GGGGGGG problem occurred to me as well when using the IQ3_K quant for some reason - https://huggingface.co/Thireus/Qwen3-Coder-480B-A35B-Instruct-THIREUS-IQ3_K-SPECIAL_SPLIT/discussions/1#6898de3d70c4fed919a7e7e9 - I thought it had something to do with the quantization of the model for that qtype being wrong... but seems like it may not be it. I've avoided using this quant and haven't been able to reproduce the issue with other quants since.

@ajunca
Copy link

ajunca commented Aug 13, 2025

In llama.cpp, the Vulkan backend is workings (but much slower).

The much slower part is no longer true.

I mean that when running with the Vulkan backend is not producing the mentioned problem. So looks like a CUDA problem.

Well just tried today. I get like 60% speed for generation (not bad at all), but only like 15% for prompt processing, so yea, quite slower on this regard.

@ajunca
Copy link

ajunca commented Aug 13, 2025

One is the gibberish problem (eg. GGGGGGGGG...), that at the end if you wait long enough (sometimes is very fast, some times it takes longer), you get a cuda crash like:

I can confirm the GGGGGGG problem occurred to me as well when using the IQ3_K quant for some reason - https://huggingface.co/Thireus/Qwen3-Coder-480B-A35B-Instruct-THIREUS-IQ3_K-SPECIAL_SPLIT/discussions/1#6898de3d70c4fed919a7e7e9 - I thought it had something to do with the quantization of the model for that qtype being wrong... but seems like it may not be it.

I tried many different, even MXFP4, and at some point and with enough testing, all seems to produce it.

@ikawrakow
Copy link
Owner

Well just tried today. I get like 60% speed for generation (not bad at all), but only like 15% for prompt processing, so yea, quite slower on this regard.

15% of what? 15% of mainline llama.cpp, 15% of the speed on the same system using the CUDA back-end? Something else?

@ajunca
Copy link

ajunca commented Aug 13, 2025

Well just tried today. I get like 60% speed for generation (not bad at all), but only like 15% for prompt processing, so yea, quite slower on this regard.

15% of what? 15% of mainline llama.cpp, 15% of the speed on the same system using the CUDA back-end? Something else?

Yea sorry, I was comparing the Vulkan vs Cuda in llama.cpp. Saying that in Vulkan it seems to work correctly, the only problem is that is considerably slower than the CUDA backend. (percentage compared to CUDA toks/s).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.