diff --git a/src/config.ts b/src/config.ts index 03be1036..3b6b039d 100644 --- a/src/config.ts +++ b/src/config.ts @@ -91,6 +91,10 @@ export interface ChatConfig { top_p: number; temperature: number; bos_token_id?: number; + // Model type identifier from mlc-chat-config.json (e.g. "phi3_v", "gemma3_v") + model_type?: string; + // Nested model config from mlc-chat-config.json, contains model-specific parameters + model_config?: Record; } /** diff --git a/src/llm_chat.ts b/src/llm_chat.ts index 70941449..74c12b15 100644 --- a/src/llm_chat.ts +++ b/src/llm_chat.ts @@ -11,7 +11,6 @@ import { getRGBArrayFromImageData, getTokenTableFromTokenizer, getTopProbs, - IMAGE_EMBED_SIZE, } from "./support"; import { ChatCompletionFinishReason, @@ -29,7 +28,6 @@ import { WindowSizeSpecificationError, MessageOrderError, TextCompletionExpectsKVEmptyError, - PrefillChunkSizeSmallerThanImageError, CannotFindImageEmbedError, } from "./error"; @@ -252,7 +250,6 @@ export class LLMChatPipeline { if (this.prefillChunkSize <= 0) { throw new MinValueError("prefill_chunk_size", 0); } - // 5. 
Consolidate KVCache settings: context window, sliding window, attention sink this.slidingWindowSize = config.sliding_window_size; this.contextWindowSize = config.context_window_size; @@ -698,23 +695,10 @@ export class LLMChatPipeline { conversation.appendReplyHeader(Role.assistant); } } - const retGetInputData = this.getInputData(); - const inputData: Array | ImageURL> = retGetInputData[0]; - const promptLen: number = retGetInputData[1]; + const [inputData, promptLen, getEmbedSize] = await this.getInputData(); // Check if LLMChatPipeline fits for forwarding image input - let hasImageInput = false; - inputData.forEach((data) => { - if (!Array.isArray(data)) { - hasImageInput = true; - } - }); - if (hasImageInput && this.prefillChunkSize < IMAGE_EMBED_SIZE) { - throw new PrefillChunkSizeSmallerThanImageError( - this.prefillChunkSize, - IMAGE_EMBED_SIZE, - ); - } + const hasImageInput = inputData.some((data) => !Array.isArray(data)); if (hasImageInput && this.image_embed === undefined) { throw new CannotFindImageEmbedError(); } @@ -723,6 +707,7 @@ export class LLMChatPipeline { const retGetChunks = getChunkedPrefillInputData( inputData, this.prefillChunkSize, + getEmbedSize, ); const chunks: Array | ImageURL>[] = retGetChunks[0]; const chunkLens: Array = retGetChunks[1]; @@ -939,6 +924,33 @@ export class LLMChatPipeline { return embed; } + /** + * Compute the number of embedding tokens an image will produce. + * Must match the model's image_embed output size. + * Based on mlc_llm/serve/data.py _compute_embed_size. + */ + private computeImageEmbedSize( + imageHeight: number, + imageWidth: number, + ): number { + const modelType = this.config.model_type; + if (modelType === "phi3_v") { + const [cropH, cropW] = this.calculateCropShape(imageHeight, imageWidth); + const subTokens = cropH * 12 * (cropW * 12 + 1); + const glbTokens = 12 * (12 + 1); + return subTokens + 1 + glbTokens; + } + // For models with fixed embed size (e.g. 
Gemma3V) + const mmTokens = this.config.model_config?.mm_tokens_per_image; + if (mmTokens !== undefined) { + return mmTokens; + } + throw new Error( + "Cannot determine image embed size. " + + "Please add mm_tokens_per_image to model_config in mlc-chat-config.json.", + ); + } + /** * Calculate resize dimensions for Phi3-V model. * Based on vlm_utils.cc CalculateResizeShape @@ -1018,9 +1030,13 @@ export class LLMChatPipeline { this.params, ), ); - if (embed.shape[0] !== IMAGE_EMBED_SIZE) { + const expectedSize = this.computeImageEmbedSize( + imgData.height, + imgData.width, + ); + if (embed.shape[0] !== expectedSize) { throw new Error( - `InternalError: expect embed.shape[0] to be ${IMAGE_EMBED_SIZE}, ` + + `InternalError: expect embed.shape[0] to be ${expectedSize}, ` + `but got ${embed.shape[0]}`, ); } @@ -1519,7 +1535,9 @@ export class LLMChatPipeline { * imageUrl2, * token array for "\nSome user input<|end|>\n" */ - private getInputData(): [Array | ImageURL>, number] { + private async getInputData(): Promise< + [Array | ImageURL>, number, (image: ImageURL) => number] + > { const ret: Array | ImageURL> = []; let curTokens: Array = []; let prompts: Array>; @@ -1546,6 +1564,32 @@ export class LLMChatPipeline { } } + // 1.5. 
Preload image dimensions to compute per-image embed sizes + const imageDimensions = new Map(); + const uniqueImageUrls = new Set(); + for (const prompt of prompts) { + if (typeof prompt !== "string") { + for (const content of prompt) { + if (typeof content !== "string") { + uniqueImageUrls.add(content.url); + } + } + } + } + await Promise.all( + Array.from(uniqueImageUrls).map(async (url) => { + const imgData = await getImageDataFromURL(url); + imageDimensions.set(url, [imgData.height, imgData.width]); + }), + ); + const getEmbedSize = (image: ImageURL): number => { + const dims = imageDimensions.get(image.url); + if (!dims) { + throw new Error("InternalError: image dimensions not preloaded"); + } + return this.computeImageEmbedSize(dims[0], dims[1]); + }; + // 2. Encode all prompts. Iterate through each message in the prompt array, where each // prompt can either be a string, or an array of a mixture of string and ImageURLs. let numPromptTokens = 0; @@ -1563,11 +1607,25 @@ export class LLMChatPipeline { numPromptTokens += encoded.length; curTokens.push(...encoded); } else { + // Insert BOI wrapping if configured (e.g. 
Gemma3V: \n + BOI token before image) + const boiToken = this.config.model_config?.boi_token_index; + if (boiToken !== undefined) { + const nlTokens = this.tokenizer.encode("\n"); + curTokens.push(...nlTokens); + curTokens.push(boiToken); + numPromptTokens += nlTokens.length + 1; + } // push curTokens to ret, push imageUrl, create a new curTokens ret.push([...curTokens]); ret.push(curPromptContent); - numPromptTokens += IMAGE_EMBED_SIZE; + numPromptTokens += getEmbedSize(curPromptContent); curTokens = []; + // Insert EOI token after image if configured + const eoiToken = this.config.model_config?.eoi_token_index; + if (eoiToken !== undefined) { + curTokens.push(eoiToken); + numPromptTokens += 1; + } } } } @@ -1587,7 +1645,7 @@ export class LLMChatPipeline { this.contextWindowSize, ); } - return [ret, numPromptTokens]; + return [ret, numPromptTokens, getEmbedSize]; } async forwardTokensAndSample( @@ -1601,6 +1659,7 @@ export class LLMChatPipeline { const retGetChunks = getChunkedPrefillInputData( inputData, this.prefillChunkSize, + () => 0, // text-only path, no images ); const chunks: Array | ImageURL>[] = retGetChunks[0]; const chunkLens: Array = retGetChunks[1]; diff --git a/src/support.ts b/src/support.ts index a2ae0d04..dcc969e7 100644 --- a/src/support.ts +++ b/src/support.ts @@ -14,6 +14,7 @@ import { ToolCallOutputMissingFieldsError, ToolCallOutputParseError, UnclearModelToUseError, + PrefillChunkSizeSmallerThanImageError, } from "./error"; /** @@ -278,11 +279,12 @@ export function getModelIdToUse( * Chunk the inputData such that each chunk's total input length is smaller than prefill * chunk size. * @returns [the data chunks, the input length of each chunk] - * @note precondition: if inputData has image in it, then prefillChunkSize >= IMAGE_EMBED_SIZE. + * @note precondition: if inputData has image in it, then prefillChunkSize >= imageEmbedSize. 
*/ export function getChunkedPrefillInputData( inputData: Array | ImageURL>, prefillChunkSize: number, + getImageEmbedSize: (image: ImageURL) => number, ): [Array | ImageURL>[], Array] { const chunks: Array | ImageURL>[] = []; const chunkLens: Array = []; @@ -292,7 +294,13 @@ export function getChunkedPrefillInputData( let curData: Array | ImageURL = inputData[i]; const curDataLen = Array.isArray(curData) ? curData.length - : IMAGE_EMBED_SIZE; + : getImageEmbedSize(curData); + if (!Array.isArray(curData) && curDataLen > prefillChunkSize) { + throw new PrefillChunkSizeSmallerThanImageError( + prefillChunkSize, + curDataLen, + ); + } // 1. curData can fit into this chunk if (curChunkLen + curDataLen <= prefillChunkSize) { curChunk.push(curData); @@ -338,7 +346,7 @@ export function getChunkedPrefillInputData( chunkLens.push(curChunkLen); // 2.2.2. Then push image to the new chunk curChunk = [curData]; - curChunkLen = IMAGE_EMBED_SIZE; + curChunkLen = curDataLen; if (curChunkLen === prefillChunkSize) { chunks.push([...curChunk]); chunkLens.push(curChunkLen); @@ -405,9 +413,6 @@ export class CustomLock { // Image related type ImageURL = ChatCompletionContentPartImage.ImageURL; -// TODO(Charlie): currently hardcoded for phi3.5-vision num_crops 16 -export const IMAGE_EMBED_SIZE = 1921; - /** * Given a url, get the image data. The url can either start with `http` or `data:image`. 
*/ diff --git a/tests/llm_chat_pipeline.test.ts b/tests/llm_chat_pipeline.test.ts index 6c437f8a..04ee6ddb 100644 --- a/tests/llm_chat_pipeline.test.ts +++ b/tests/llm_chat_pipeline.test.ts @@ -201,10 +201,9 @@ function preparePrefillPipeline(): PipelineLike { const pipeline = createPipeline(); pipeline["prefillTotalTime"] = 0; pipeline["prefillTotalTokens"] = 0; - pipeline["getInputData"] = jest.fn<() => [number[][], number]>(() => [ - [[0]], - 1, - ]); + pipeline["getInputData"] = jest.fn( + async (): Promise<[any[], number, any]> => [[[0]], 1, () => 0], + ); pipeline["processNextToken"] = jest.fn(); return pipeline; } @@ -272,15 +271,15 @@ test("prefillStep compiles custom grammar when response type is grammar", async expect(compileGrammarMock).toHaveBeenCalledWith("root ::= WORD"); }); -test("getInputData uses cached prompts when KV cache filled", () => { +test("getInputData uses cached prompts when KV cache filled", async () => { const pipeline = createPipeline(); pipeline["tokenizer"].encode = jest.fn(() => Int32Array.from([1])); pipeline["conversation"].config.system_prefix_token_ids = undefined; pipeline["filledKVCacheLength"] = 0; - (pipeline as any).getInputData(); + await (pipeline as any).getInputData(); expect(pipeline["conversation"].getPromptArray).toHaveBeenCalled(); pipeline["filledKVCacheLength"] = 1; - (pipeline as any).getInputData(); + await (pipeline as any).getInputData(); expect(pipeline["conversation"].getPromptArrayLastRound).toHaveBeenCalled(); }); @@ -292,3 +291,74 @@ test("processNextToken ignores eos when requested", () => { expect(pipeline["finishReason"]).toBeUndefined(); expect(pipeline["outputIds"]).toContain(1); }); + +describe("calculateResizeShape", () => { + test("square image", () => { + const pipeline = createPipeline(); + expect(pipeline["calculateResizeShape"](336, 336)).toEqual([1344, 1344]); + }); + + test("landscape image", () => { + const pipeline = createPipeline(); + expect(pipeline["calculateResizeShape"](1080, 
1920)).toEqual([945, 1680]); + }); + + test("portrait image", () => { + const pipeline = createPipeline(); + expect(pipeline["calculateResizeShape"](1920, 1080)).toEqual([1194, 672]); + }); +}); + +describe("calculateCropShape", () => { + test("square image", () => { + const pipeline = createPipeline(); + expect(pipeline["calculateCropShape"](336, 336)).toEqual([4, 4]); + }); + + test("landscape image", () => { + const pipeline = createPipeline(); + expect(pipeline["calculateCropShape"](1080, 1920)).toEqual([3, 5]); + }); + + test("portrait image", () => { + const pipeline = createPipeline(); + expect(pipeline["calculateCropShape"](1920, 1080)).toEqual([4, 2]); + }); +}); + +describe("computeImageEmbedSize", () => { + test("phi3_v square image", () => { + const pipeline = createPipeline(); + pipeline["config"] = { model_type: "phi3_v" } as any; + expect(pipeline["computeImageEmbedSize"](336, 336)).toBe(2509); + }); + + test("phi3_v landscape image", () => { + const pipeline = createPipeline(); + pipeline["config"] = { model_type: "phi3_v" } as any; + expect(pipeline["computeImageEmbedSize"](1080, 1920)).toBe(2353); + }); + + test("phi3_v portrait image", () => { + const pipeline = createPipeline(); + pipeline["config"] = { model_type: "phi3_v" } as any; + expect(pipeline["computeImageEmbedSize"](1920, 1080)).toBe(1357); + }); + + test("model with mm_tokens_per_image", () => { + const pipeline = createPipeline(); + pipeline["config"] = { + model_type: "gemma3_v", + model_config: { mm_tokens_per_image: 256 }, + } as any; + expect(pipeline["computeImageEmbedSize"](1080, 1920)).toBe(256); + }); + + test("unknown model without mm_tokens throws", () => { + const pipeline = createPipeline(); + pipeline["config"] = { model_type: "unknown_model" } as any; + expect(() => pipeline["computeImageEmbedSize"](336, 336)).toThrow( + "Cannot determine image embed size", + ); + }); +}); diff --git a/tests/util.test.ts b/tests/util.test.ts index bc29c2e2..df6bb590 100644 --- 
a/tests/util.test.ts +++ b/tests/util.test.ts @@ -1,6 +1,7 @@ import { ChatOptions } from "../src/config"; import { ModelNotLoadedError, + PrefillChunkSizeSmallerThanImageError, SpecifiedModelNotFoundError, UnclearModelToUseError, } from "../src/error"; @@ -329,6 +330,7 @@ describe("Test getChunkedPrefillInputData", () => { const prefillChunkSize = 2048; const image1 = { url: "url1" } as ImageURL; const image2 = { url: "url2" } as ImageURL; + const getImageEmbedSize = () => 1921; test("With image data", async () => { const inputData = [ @@ -336,7 +338,11 @@ describe("Test getChunkedPrefillInputData", () => { image1, // 1921 size rangeArr(0, 10), ]; - const chunks = getChunkedPrefillInputData(inputData, prefillChunkSize); + const chunks = getChunkedPrefillInputData( + inputData, + prefillChunkSize, + getImageEmbedSize, + ); const expectedChunks = [[rangeArr(0, 200)], [image1, rangeArr(0, 10)]]; const expectedChunkLens = [200, 1931]; expect(chunks).toEqual([expectedChunks, expectedChunkLens]); @@ -344,7 +350,11 @@ describe("Test getChunkedPrefillInputData", () => { test("Single image data", async () => { const inputData = [image1]; - const chunks = getChunkedPrefillInputData(inputData, prefillChunkSize); + const chunks = getChunkedPrefillInputData( + inputData, + prefillChunkSize, + getImageEmbedSize, + ); const expectedChunks = [[image1]]; const expectedChunkLens = [1921]; expect(chunks).toEqual([expectedChunks, expectedChunkLens]); @@ -352,7 +362,11 @@ describe("Test getChunkedPrefillInputData", () => { test("Two images", async () => { const inputData = [image1, image2]; - const chunks = getChunkedPrefillInputData(inputData, prefillChunkSize); + const chunks = getChunkedPrefillInputData( + inputData, + prefillChunkSize, + getImageEmbedSize, + ); const expectedChunks = [[image1], [image2]]; const expectedChunkLens = [1921, 1921]; expect(chunks).toEqual([expectedChunks, expectedChunkLens]); @@ -360,7 +374,11 @@ describe("Test getChunkedPrefillInputData", () => { 
test("Single token array that needs to be chunked", async () => { const inputData = [rangeArr(0, 4097)]; - const chunks = getChunkedPrefillInputData(inputData, prefillChunkSize); + const chunks = getChunkedPrefillInputData( + inputData, + prefillChunkSize, + getImageEmbedSize, + ); const expectedChunks = [ [rangeArr(0, 2048)], [rangeArr(2048, 4096)], @@ -372,7 +390,11 @@ describe("Test getChunkedPrefillInputData", () => { test("Single token array that does not need to be chunked", async () => { const inputData = [rangeArr(0, 2048)]; - const chunks = getChunkedPrefillInputData(inputData, prefillChunkSize); + const chunks = getChunkedPrefillInputData( + inputData, + prefillChunkSize, + getImageEmbedSize, + ); const expectedChunks = [[rangeArr(0, 2048)]]; const expectedChunkLens = [2048]; expect(chunks).toEqual([expectedChunks, expectedChunkLens]); @@ -384,7 +406,11 @@ describe("Test getChunkedPrefillInputData", () => { rangeArr(0, 2300), image2, ]; - const chunks = getChunkedPrefillInputData(inputData, prefillChunkSize); + const chunks = getChunkedPrefillInputData( + inputData, + prefillChunkSize, + getImageEmbedSize, + ); const expectedChunks = [ [image1, rangeArr(0, 127)], // 127 = 2048 - 1921 [rangeArr(127, 2175)], // 2175 = 127 + 2048 @@ -400,11 +426,44 @@ describe("Test getChunkedPrefillInputData", () => { rangeArr(0, 127), image2, ]; - const chunks = getChunkedPrefillInputData(inputData, prefillChunkSize); + const chunks = getChunkedPrefillInputData( + inputData, + prefillChunkSize, + getImageEmbedSize, + ); const expectedChunks = [[image1, rangeArr(0, 127)], [image2]]; const expectedChunkLens = [2048, 1921]; expect(chunks).toEqual([expectedChunks, expectedChunkLens]); }); + + test("Throws when image embed size exceeds prefill chunk size", () => { + const inputData = [image1]; + expect(() => + getChunkedPrefillInputData(inputData, 100, () => 1921), + ).toThrow(PrefillChunkSizeSmallerThanImageError); + }); + + test("Dynamic per-image embed sizes", () => { + 
const sizeMap: Record<string, number> = { url1: 500, url2: 1500 }; + const getDynamicSize = (img: ImageURL) => sizeMap[img.url]; + const inputData = [ + rangeArr(0, 100), + image1, // 500 + rangeArr(0, 50), + image2, // 1500 + ]; + const chunks = getChunkedPrefillInputData( + inputData, + prefillChunkSize, + getDynamicSize, + ); + const expectedChunks = [ + [rangeArr(0, 100), image1, rangeArr(0, 50)], + [image2], + ]; + const expectedChunkLens = [650, 1500]; + expect(chunks).toEqual([expectedChunks, expectedChunkLens]); + }); }); // Refers to https://jackpordi.com/posts/locks-in-js-because-why-not