26 changes: 24 additions & 2 deletions convert_hf_to_gguf.py
@@ -2289,18 +2289,21 @@ def set_gguf_parameters(self):
 )
 class LlavaVisionModel(MmprojModel):
     img_break_tok_id = -1
+    use_break_tok = True
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if self.hparams.get("model_type") == "pixtral":
             # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
             self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
-            self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
+            if self.use_break_tok:
+                self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
         elif self.is_mistral_format:
             # hparams is already vision config here so norm_eps is only defined in global_config.
             self.hparams["norm_eps"] = self.global_config.get("norm_eps", None)
             assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json"
-            self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
+            if self.use_break_tok:
+                self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
         else:
             raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
         logger.info(f"Image break token id: {self.img_break_tok_id}")
@@ -3791,6 +3794,10 @@ def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor:
         return torch.stack([true_row, false_row], dim=0)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if "model.vision_" in name:
+            # skip multimodal tensors
+            return []
+
         if self.is_rerank:
             is_tied_head = self.is_tied_embeddings and "embed_tokens" in name
             is_real_head = not self.is_tied_embeddings and "lm_head" in name
@@ -9280,6 +9287,21 @@ def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", "
         return super().map_tensor_name(name, try_suffixes)
 
 
+@ModelBase.register("LightOnOCRForConditionalGeneration")
+class LightOnOCRVisionModel(LlavaVisionModel):
+    is_mistral_format = False
+    use_break_tok = False
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LIGHTONOCR)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        name = name.replace("model.vision_encoder.", "vision_tower.")
+        name = name.replace("model.vision_projection.", "multi_modal_projector.")
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("KimiVLForConditionalGeneration")
 class KimiVLModel(MmprojModel):
     def __init__(self, *args, **kwargs):
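Note: LightOnOCRVisionModel reuses the Pixtral conversion path (is_mistral_format = False) and opts out of [IMG_BREAK] resolution via use_break_tok = False, so img_break_tok_id stays at -1. A minimal sketch of what the two replace() calls do before the inherited Llava tensor mapping runs — the example tensor names below are hypothetical, for illustration only:

    # Hypothetical LightOnOCR checkpoint names, remapped to the prefixes
    # that LlavaVisionModel.modify_tensors() expects.
    for name in (
        "model.vision_encoder.transformer.layers.0.attention.q_proj.weight",
        "model.vision_projection.patch_merger.merging_layer.weight",
    ):
        name = name.replace("model.vision_encoder.", "vision_tower.")
        name = name.replace("model.vision_projection.", "multi_modal_projector.")
        print(name)
    # -> vision_tower.transformer.layers.0.attention.q_proj.weight
    # -> multi_modal_projector.patch_merger.merging_layer.weight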
1 change: 1 addition & 0 deletions gguf-py/gguf/constants.py
@@ -3062,6 +3062,7 @@ class VisionProjectorType:
     VOXTRAL = "voxtral"
     LFM2 = "lfm2"
     KIMIVL = "kimivl"
+    LIGHTONOCR = "lightonocr"
 
 
 # Items here are (block size, type size)
2 changes: 2 additions & 0 deletions tools/mtmd/clip-impl.h
@@ -139,6 +139,7 @@ enum projector_type {
     PROJECTOR_TYPE_VOXTRAL,
     PROJECTOR_TYPE_LFM2,
     PROJECTOR_TYPE_KIMIVL,
+    PROJECTOR_TYPE_LIGHTONOCR,
     PROJECTOR_TYPE_UNKNOWN,
 };

@@ -161,6 +162,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_VOXTRAL,   "voxtral"},
     { PROJECTOR_TYPE_LFM2,      "lfm2"},
     { PROJECTOR_TYPE_KIMIVL,    "kimivl"},
+    { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
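Note: this map entry is what lets a converted GGUF round-trip: the projector-type string written by the Python side is resolved back to the enum by clip_projector_type_from_string(), which scans PROJECTOR_TYPE_NAMES for a matching value. A Python analogue of that lookup (a sketch, not a port of the C++ body):

    # Reverse-lookup of the enum from the projector-type string in the GGUF file.
    PROJECTOR_TYPE_NAMES = {
        "PROJECTOR_TYPE_PIXTRAL":    "pixtral",
        "PROJECTOR_TYPE_LIGHTONOCR": "lightonocr",
    }

    def projector_type_from_string(s: str) -> str:
        for proj_type, name in PROJECTOR_TYPE_NAMES.items():
            if name == s:
                return proj_type
        return "PROJECTOR_TYPE_UNKNOWN"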
26 changes: 23 additions & 3 deletions tools/mtmd/clip.cpp
@@ -621,7 +621,7 @@ struct clip_graph {
         }
 
         // arrangement of the [IMG_BREAK] token
-        {
+        if (model.token_embd_img_break) {
             // not efficient, but works
             // the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows]
             // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
@@ -2095,6 +2095,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
                 res = graph.build_siglip();
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
+        case PROJECTOR_TYPE_LIGHTONOCR:
             {
                 res = graph.build_pixtral();
             } break;
@@ -2380,6 +2381,7 @@ struct clip_model_loader {
                     get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
                 } break;
             case PROJECTOR_TYPE_PIXTRAL:
+            case PROJECTOR_TYPE_LIGHTONOCR:
                 {
                     hparams.rope_theta = 10000.0f;
                     hparams.warmup_image_size = hparams.patch_size * 8;
@@ -2722,6 +2724,15 @@ struct clip_model_loader {
                     model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
                     model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
                 } break;
+            case PROJECTOR_TYPE_LIGHTONOCR:
+                {
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
+                    model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
+                    model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
+                    model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
+                    model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
+                } break;
             case PROJECTOR_TYPE_ULTRAVOX:
                 {
                     model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
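Note: the LIGHTONOCR tensor set is identical to the PIXTRAL case above — a two-layer LLaVA-style MLP (mm_1_*, mm_2_*, biases optional) plus optional input-norm and patch-merger weights — which is why both reach the shared build_pixtral() graph. A rough sketch of the MLP these tensors feed, assuming that shared path rather than a line-for-line port; shapes are made up for illustration:

    import numpy as np

    def gelu(x):  # tanh approximation
        return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

    x = np.random.randn(16, 64)                # 16 merged patches, vision n_embd = 64
    mm_1_w, mm_1_b = np.random.randn(128, 64), np.zeros(128)
    mm_2_w, mm_2_b = np.random.randn(96, 128), np.zeros(96)
    x = gelu(x @ mm_1_w.T + mm_1_b)            # TN_LLAVA_PROJ 1
    x = x @ mm_2_w.T + mm_2_b                  # TN_LLAVA_PROJ 2 -> LLM n_embd = 96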
@@ -3622,7 +3633,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         res_imgs->entries.push_back(std::move(img_f32));
         return true;
 
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) {
+    } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL
+            || ctx->proj_type() == PROJECTOR_TYPE_LIGHTONOCR
+            ) {
         clip_image_u8 resized_image;
         auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
         image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
@@ -3865,12 +3878,17 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
                 n_patches = x_patch * y_patch;
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
+        case PROJECTOR_TYPE_LIGHTONOCR:
             {
                 // dynamic size
                 int n_merge = params.spatial_merge_size;
                 int n_patches_x = img->nx / patch_size / (n_merge > 0 ? n_merge : 1);
                 int n_patches_y = img->ny / patch_size / (n_merge > 0 ? n_merge : 1);
-                n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
+                if (ctx->model.token_embd_img_break) {
+                    n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
+                } else {
+                    n_patches = n_patches_y * n_patches_x;
+                }
             } break;
         case PROJECTOR_TYPE_VOXTRAL:
         case PROJECTOR_TYPE_ULTRAVOX:
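Note: a worked example of the token-count difference (the image dimensions are hypothetical):

    # 1024x768 image, patch_size = 16, spatial_merge_size = 2
    patch_size, n_merge = 16, 2
    nx, ny = 1024, 768
    n_patches_x = nx // patch_size // n_merge                    # 32
    n_patches_y = ny // patch_size // n_merge                    # 24
    with_break    = n_patches_y * n_patches_x + n_patches_y - 1  # 791: Pixtral, one [IMG_BREAK] per row except the last
    without_break = n_patches_y * n_patches_x                    # 768: LightOnOCR, no [IMG_BREAK] embedding loaded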
@@ -4247,6 +4265,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
         case PROJECTOR_TYPE_KIMIVL:
+        case PROJECTOR_TYPE_LIGHTONOCR:
             {
                 // set the 2D positions
                 int n_patches_per_col = image_size_width / patch_size;
@@ -4377,6 +4396,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_model_peg_0_b->ne[0];
         case PROJECTOR_TYPE_MLP:
         case PROJECTOR_TYPE_PIXTRAL:
+        case PROJECTOR_TYPE_LIGHTONOCR:
             return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_MLP_NORM:
             return ctx->model.mm_3_b->ne[0];
5 changes: 5 additions & 0 deletions tools/mtmd/mtmd.cpp
@@ -275,6 +275,11 @@ struct mtmd_context {
             img_beg = "<img>";
             img_end = "</img>";
 
+        } else if (proj == PROJECTOR_TYPE_LIGHTONOCR) {
+            // <|im_start|> ... (image embeddings) ... <|im_end|>
+            img_beg = "<|im_start|>";
+            img_end = "<|im_end|>";
+
         }
     }
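Note: a sketch of the prompt mtmd assembles around the image slot for LightOnOCR — the marker strings come from the diff above, while the trailing instruction text is hypothetical:

    img_beg, img_end = "<|im_start|>", "<|im_end|>"
    # the image embeddings themselves are injected between the markers at eval time
    prompt = img_beg + "<...image embeddings...>" + img_end + "Transcribe the text in this image."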
