Skip to content

Commit db8dc06

Browse files
Merge pull request #52 from menloresearch/update-dev-from-master-2025-04-11-00-08
Sync master with upstream release b5106
2 parents 4969c29 + 47ba87d commit db8dc06

33 files changed

+1391
-693
lines changed

.devops/cuda.Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ COPY . .
2121
RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \
2222
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \
2323
fi && \
24-
cmake -B build -DGGML_NATIVE=OFF -DGGML_CUDA=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
24+
cmake -B build -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DLLAMA_CURL=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
2525
cmake --build build --config Release -j$(nproc)
2626

2727
RUN mkdir -p /app/lib && \

.devops/intel.Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \
1717
&& export OPT_SYCL_F16="-DGGML_SYCL_F16=ON"; \
1818
fi && \
1919
echo "Building with dynamic libs" && \
20-
cmake -B build -DGGML_NATIVE=OFF -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx ${OPT_SYCL_F16} && \
20+
cmake -B build -DGGML_NATIVE=OFF -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_CURL=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON ${OPT_SYCL_F16} && \
2121
cmake --build build --config Release -j$(nproc)
2222

2323
RUN mkdir -p /app/lib && \

.devops/musa.Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ COPY . .
3535
RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \
3636
export CMAKE_ARGS="-DMUSA_ARCHITECTURES=${MUSA_DOCKER_ARCH}"; \
3737
fi && \
38-
cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
38+
cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON -DLLAMA_CURL=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
3939
cmake --build build --config Release -j$(nproc)
4040

4141
RUN mkdir -p /app/lib && \

.devops/rocm.Dockerfile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build
1717
# gfx906 is deprecated
1818
#check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.2.4/reference/system-requirements.html
1919

20-
#ARG ROCM_DOCKER_ARCH='gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102'
21-
ARG ROCM_DOCKER_ARCH=gfx1100
20+
ARG ROCM_DOCKER_ARCH='gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102'
21+
#ARG ROCM_DOCKER_ARCH=gfx1100
2222

2323
# Set nvcc architectured
2424
ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH}
@@ -40,7 +40,7 @@ WORKDIR /app
4040
COPY . .
4141

4242
RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
43-
cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=$ROCM_DOCKER_ARCH -DCMAKE_BUILD_TYPE=Release \
43+
cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=$ROCM_DOCKER_ARCH -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DCMAKE_BUILD_TYPE=Release -DLLAMA_CURL=ON \
4444
&& cmake --build build --config Release -j$(nproc)
4545

4646
RUN mkdir -p /app/lib \

.devops/vulkan.Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ WORKDIR /app
1616

1717
COPY . .
1818

19-
RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=1 -DLLAMA_CURL=1 && \
19+
RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=1 -DLLAMA_CURL=1 -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON && \
2020
cmake --build build --config Release -j$(nproc)
2121

2222
RUN mkdir -p /app/lib && \

convert_hf_to_gguf.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class Model:
6565
model_name: str | None
6666
metadata_override: Path | None
6767
dir_model_card: Path
68+
remote_hf_model_id: str | None
6869

6970
# subclasses should define this!
7071
model_arch: gguf.MODEL_ARCH
@@ -73,7 +74,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
7374
use_temp_file: bool = False, eager: bool = False,
7475
metadata_override: Path | None = None, model_name: str | None = None,
7576
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
76-
small_first_shard: bool = False, hparams: dict[str, Any] | None = None):
77+
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
7778
if type(self) is Model:
7879
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
7980

@@ -83,11 +84,24 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
8384
self.is_big_endian = is_big_endian
8485
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
8586
self.use_temp_file = use_temp_file
86-
self.lazy = not eager
87-
self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
88-
self.is_safetensors = len(self.part_names) > 0
89-
if not self.is_safetensors:
90-
self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
87+
self.lazy = not eager or (remote_hf_model_id is not None)
88+
self.remote_hf_model_id = remote_hf_model_id
89+
if remote_hf_model_id is not None:
90+
self.is_safetensors = True
91+
92+
def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
93+
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
94+
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
95+
self.tensor_names = set(name for name in remote_tensors.keys())
96+
for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id).items():
97+
yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
98+
99+
self.get_tensors = get_remote_tensors
100+
else:
101+
self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
102+
self.is_safetensors = len(self.part_names) > 0
103+
if not self.is_safetensors:
104+
self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
91105
self.hparams = Model.load_hparams(self.dir_model) if hparams is None else hparams
92106
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
93107
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
@@ -393,6 +407,10 @@ def prepare_metadata(self, vocab_only: bool):
393407

394408
self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model_card, self.model_name, total_params)
395409

410+
# If we are using HF model id, set the metadata name to the model id
411+
if self.remote_hf_model_id:
412+
self.metadata.name = self.remote_hf_model_id
413+
396414
# Fallback to model directory name if metadata name is still missing
397415
if self.metadata.name is None:
398416
self.metadata.name = self.dir_model.name
@@ -5403,6 +5421,14 @@ def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
54035421
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
54045422
return cast(torch.Tensor, lazy)
54055423

5424+
@classmethod
5425+
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
5426+
dtype = cls._dtype_str_map[remote_tensor.dtype]
5427+
shape = remote_tensor.shape
5428+
meta = cls.meta_with_dtype_and_shape(dtype, shape)
5429+
lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape))
5430+
return cast(torch.Tensor, lazy)
5431+
54065432
@classmethod
54075433
def __torch_function__(cls, func, types, args=(), kwargs=None):
54085434
del types # unused
@@ -5480,6 +5506,10 @@ def parse_args() -> argparse.Namespace:
54805506
"--print-supported-models", action="store_true",
54815507
help="Print the supported models"
54825508
)
5509+
parser.add_argument(
5510+
"--remote", action="store_true",
5511+
help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B-Instruct'. Note: To access gated repo, set HF_TOKEN environment variable to your Hugging Face token.",
5512+
)
54835513

54845514
args = parser.parse_args()
54855515
if not args.print_supported_models and args.model is None:
@@ -5520,6 +5550,14 @@ def main() -> None:
55205550

55215551
dir_model = args.model
55225552

5553+
if args.remote:
5554+
from huggingface_hub import snapshot_download
5555+
local_dir = snapshot_download(
5556+
repo_id=str(dir_model),
5557+
allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
5558+
dir_model = Path(local_dir)
5559+
logger.info(f"Downloaded config and tokenizer to {local_dir}")
5560+
55235561
if not dir_model.is_dir():
55245562
logger.error(f'Error: {args.model} is not a directory')
55255563
sys.exit(1)
@@ -5541,6 +5579,9 @@ def main() -> None:
55415579

55425580
if args.outfile is not None:
55435581
fname_out = args.outfile
5582+
elif args.remote:
5583+
# if remote, use the model ID as the output file name
5584+
fname_out = Path("./" + str(args.model).replace("/", "-") + "-{ftype}.gguf")
55445585
else:
55455586
fname_out = dir_model
55465587

@@ -5564,7 +5605,8 @@ def main() -> None:
55645605
metadata_override=args.metadata, model_name=args.model_name,
55655606
split_max_tensors=args.split_max_tensors,
55665607
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
5567-
small_first_shard=args.no_tensor_first_split)
5608+
small_first_shard=args.no_tensor_first_split,
5609+
remote_hf_model_id=str(args.model) if args.remote else None)
55685610

55695611
if args.vocab_only:
55705612
logger.info("Exporting model vocab...")

examples/llava/CMakeLists.txt

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# llava (legacy)
2+
13
add_library(llava OBJECT
24
llava.cpp
35
llava.h
@@ -22,12 +24,41 @@ if (BUILD_SHARED_LIBS)
2224
install(TARGETS llava_shared LIBRARY)
2325
endif()
2426

27+
# mtmd
28+
29+
add_library(mtmd OBJECT
30+
mtmd.cpp
31+
mtmd.h
32+
clip.cpp
33+
clip.h
34+
clip-impl.h
35+
)
36+
37+
target_link_libraries(mtmd PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
38+
39+
target_include_directories(mtmd PUBLIC .)
40+
target_include_directories(mtmd PRIVATE ../..)
41+
target_include_directories(mtmd PRIVATE ../../common) # for stb_image.h
42+
43+
target_compile_features(mtmd PRIVATE cxx_std_17)
44+
45+
add_library(mtmd_static STATIC $<TARGET_OBJECTS:mtmd>)
46+
if (BUILD_SHARED_LIBS)
47+
set_target_properties(mtmd PROPERTIES POSITION_INDEPENDENT_CODE ON)
48+
target_compile_definitions(mtmd PRIVATE LLAMA_SHARED LLAMA_BUILD)
49+
add_library(mtmd_shared SHARED $<TARGET_OBJECTS:mtmd>)
50+
target_link_libraries(mtmd_shared PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
51+
install(TARGETS mtmd_shared LIBRARY)
52+
endif()
53+
2554
if (NOT MSVC)
2655
target_compile_options(llava PRIVATE -Wno-cast-qual) # stb_image.h
56+
target_compile_options(mtmd PRIVATE -Wno-cast-qual) # stb_image.h
2757
endif()
2858

2959
if(TARGET BUILD_INFO)
3060
add_dependencies(llava BUILD_INFO)
61+
add_dependencies(mtmd BUILD_INFO)
3162
endif()
3263

3364
set(TARGET llama-llava-cli)
@@ -55,7 +86,7 @@ set(TARGET llama-gemma3-cli)
5586
add_executable(${TARGET} gemma3-cli.cpp)
5687
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-gemma3-cli)
5788
install(TARGETS ${TARGET} RUNTIME)
58-
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
89+
target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT})
5990
target_compile_features(${TARGET} PRIVATE cxx_std_17)
6091

6192
set(TARGET llama-llava-clip-quantize-cli)

examples/llava/clip-impl.h

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
#include "ggml.h"
22
#include "gguf.h"
33

4+
#include "clip.h"
5+
46
#include <climits>
57
#include <cstdarg>
68
#include <string>
79
#include <map>
810
#include <sstream>
911
#include <vector>
12+
#include <memory>
1013

1114
// Internal header for clip.cpp
1215

@@ -120,6 +123,23 @@ static projector_type clip_projector_type_from_string(const std::string & str) {
120123
return PROJECTOR_TYPE_UNKNOWN;
121124
}
122125

126+
// RGB uint8 image
127+
struct clip_image_u8 {
128+
int nx;
129+
int ny;
130+
131+
std::vector<uint8_t> buf;
132+
};
133+
134+
// RGB float32 image (NHWC)
135+
// Memory layout: RGBRGBRGB...
136+
struct clip_image_f32 {
137+
int nx;
138+
int ny;
139+
140+
std::vector<float> buf;
141+
};
142+
123143
//
124144
// logging
125145
//
@@ -178,6 +198,28 @@ static void clip_log_internal(enum ggml_log_level level, const char * format, ..
178198
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
179199
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, __VA_ARGS__)
180200

201+
//
202+
// cpp wrappers
203+
//
204+
205+
struct clip_image_u8_deleter {
206+
void operator()(clip_image_u8 * val) { clip_image_u8_free(val); }
207+
};
208+
209+
struct clip_image_f32_deleter {
210+
void operator()(clip_image_f32 * val) { clip_image_f32_free(val); }
211+
};
212+
213+
struct clip_image_f32_batch_deleter {
214+
void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); }
215+
};
216+
217+
typedef std::unique_ptr<clip_image_u8, clip_image_u8_deleter> clip_image_u8_ptr;
218+
typedef std::unique_ptr<clip_image_f32, clip_image_f32_deleter> clip_image_f32_ptr;
219+
typedef std::unique_ptr<clip_image_f32_batch, clip_image_f32_batch_deleter> clip_image_f32_batch_ptr;
220+
221+
// TODO @ngxson : we're currently having a naming clash between struct clip_image_size and function clip_image_size()
222+
181223
//
182224
// common utils
183225
//
@@ -214,6 +256,20 @@ static void string_replace_all(std::string & s, const std::string & search, cons
214256
s = std::move(builder);
215257
}
216258

259+
// split string by a `std::string delim` instead of `char delim`
260+
static std::vector<std::string> string_split_str(std::string s, const std::string & delimiter) {
261+
std::vector<std::string> tokens;
262+
size_t pos = 0;
263+
std::string token;
264+
while ((pos = s.find(delimiter)) != std::string::npos) {
265+
token = s.substr(0, pos);
266+
tokens.push_back(token);
267+
s.erase(0, pos + delimiter.length());
268+
}
269+
tokens.push_back(s);
270+
return tokens;
271+
}
272+
217273
//
218274
// gguf utils
219275
//
@@ -271,3 +327,9 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
271327
return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0);
272328
}
273329
}
330+
331+
//
332+
// API used internally with mtmd
333+
//
334+
335+
projector_type clip_get_projector_type(const struct clip_ctx * ctx);

examples/llava/clip.cpp

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,6 @@ struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callbac
3232

3333
//#define CLIP_DEBUG_FUNCTIONS
3434

35-
// RGB uint8 image
36-
struct clip_image_u8 {
37-
int nx;
38-
int ny;
39-
40-
std::vector<uint8_t> buf;
41-
};
42-
43-
// RGB float32 image (NHWC)
44-
// Memory layout: RGBRGBRGB...
45-
struct clip_image_f32 {
46-
int nx;
47-
int ny;
48-
49-
std::vector<float> buf;
50-
};
51-
5235
#ifdef CLIP_DEBUG_FUNCTIONS
5336
static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::string& filename) {
5437
std::ofstream file(filename, std::ios::binary);
@@ -1614,6 +1597,12 @@ struct clip_image_f32 * clip_image_f32_init() {
16141597
return new clip_image_f32();
16151598
}
16161599

1600+
unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny) {
1601+
if (nx) *nx = img->nx;
1602+
if (ny) *ny = img->ny;
1603+
return img->buf.data();
1604+
}
1605+
16171606
void clip_image_size_free(struct clip_image_size * load_image_size) {
16181607
if (load_image_size == nullptr) {
16191608
return;
@@ -2346,6 +2335,8 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
23462335
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
23472336
int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
23482337
n_patches = x_patch * y_patch;
2338+
} else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
2339+
n_patches = 256;
23492340
}
23502341

23512342
return n_patches;
@@ -2893,3 +2884,11 @@ bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img,
28932884
clip_image_encode(ctx, n_threads, &clip_img, vec);
28942885
return true;
28952886
}
2887+
2888+
//
2889+
// API used internally with mtmd
2890+
//
2891+
2892+
projector_type clip_get_projector_type(const struct clip_ctx * ctx) {
2893+
return ctx->proj_type;
2894+
}

examples/llava/clip.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ CLIP_API struct clip_image_size * clip_image_size_init();
7777
CLIP_API struct clip_image_u8 * clip_image_u8_init ();
7878
CLIP_API struct clip_image_f32 * clip_image_f32_init();
7979

80+
// nx, ny are the output image dimensions
81+
CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
82+
8083
CLIP_API void clip_image_size_free (struct clip_image_size * img_size);
8184
CLIP_API void clip_image_u8_free (struct clip_image_u8 * img);
8285
CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);

0 commit comments

Comments
 (0)