Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ export interface ChatConfig {
top_p: number;
temperature: number;
bos_token_id?: number;
// Model type identifier from mlc-chat-config.json (e.g. "phi3_v", "gemma3_v")
model_type?: string;
// Nested model config from mlc-chat-config.json, contains model-specific parameters
model_config?: Record<string, any>;
}

/**
Expand Down
105 changes: 82 additions & 23 deletions src/llm_chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import {
getRGBArrayFromImageData,
getTokenTableFromTokenizer,
getTopProbs,
IMAGE_EMBED_SIZE,
} from "./support";
import {
ChatCompletionFinishReason,
Expand All @@ -29,7 +28,6 @@ import {
WindowSizeSpecificationError,
MessageOrderError,
TextCompletionExpectsKVEmptyError,
PrefillChunkSizeSmallerThanImageError,
CannotFindImageEmbedError,
} from "./error";

Expand Down Expand Up @@ -252,7 +250,6 @@ export class LLMChatPipeline {
if (this.prefillChunkSize <= 0) {
throw new MinValueError("prefill_chunk_size", 0);
}

// 5. Consolidate KVCache settings: context window, sliding window, attention sink
this.slidingWindowSize = config.sliding_window_size;
this.contextWindowSize = config.context_window_size;
Expand Down Expand Up @@ -698,23 +695,10 @@ export class LLMChatPipeline {
conversation.appendReplyHeader(Role.assistant);
}
}
const retGetInputData = this.getInputData();
const inputData: Array<Array<number> | ImageURL> = retGetInputData[0];
const promptLen: number = retGetInputData[1];
const [inputData, promptLen, getEmbedSize] = await this.getInputData();

// Check if LLMChatPipeline fits for forwarding image input
let hasImageInput = false;
inputData.forEach((data) => {
if (!Array.isArray(data)) {
hasImageInput = true;
}
});
if (hasImageInput && this.prefillChunkSize < IMAGE_EMBED_SIZE) {
throw new PrefillChunkSizeSmallerThanImageError(
this.prefillChunkSize,
IMAGE_EMBED_SIZE,
);
}
const hasImageInput = inputData.some((data) => !Array.isArray(data));
if (hasImageInput && this.image_embed === undefined) {
throw new CannotFindImageEmbedError();
}
Expand All @@ -723,6 +707,7 @@ export class LLMChatPipeline {
const retGetChunks = getChunkedPrefillInputData(
inputData,
this.prefillChunkSize,
getEmbedSize,
);
const chunks: Array<Array<number> | ImageURL>[] = retGetChunks[0];
const chunkLens: Array<number> = retGetChunks[1];
Expand Down Expand Up @@ -939,6 +924,33 @@ export class LLMChatPipeline {
return embed;
}

/**
 * Compute the number of embedding tokens an image will produce.
 * Must match the model's image_embed output size.
 * Based on mlc_llm/serve/data.py _compute_embed_size.
 */
private computeImageEmbedSize(
  imageHeight: number,
  imageWidth: number,
): number {
  // Phi3-V: the token count depends on the crop grid derived from the image size.
  if (this.config.model_type === "phi3_v") {
    const [numCropsH, numCropsW] = this.calculateCropShape(
      imageHeight,
      imageWidth,
    );
    // Each crop row contributes a 12x12 patch grid plus one separator token per
    // row; the global image view adds a fixed 12x(12+1) block, joined by one
    // extra token between the two.
    const subImageTokens = numCropsH * 12 * (numCropsW * 12 + 1);
    const globalImageTokens = 12 * (12 + 1);
    return subImageTokens + 1 + globalImageTokens;
  }
  // Other vision models (e.g. Gemma3V) declare a fixed per-image token count
  // in their nested model config.
  const fixedTokens = this.config.model_config?.mm_tokens_per_image;
  if (fixedTokens === undefined) {
    throw new Error(
      "Cannot determine image embed size. " +
        "Please add mm_tokens_per_image to model_config in mlc-chat-config.json.",
    );
  }
  return fixedTokens;
}

/**
* Calculate resize dimensions for Phi3-V model.
* Based on vlm_utils.cc CalculateResizeShape
Expand Down Expand Up @@ -1018,9 +1030,13 @@ export class LLMChatPipeline {
this.params,
),
);
if (embed.shape[0] !== IMAGE_EMBED_SIZE) {
const expectedSize = this.computeImageEmbedSize(
imgData.height,
imgData.width,
);
if (embed.shape[0] !== expectedSize) {
throw new Error(
`InternalError: expect embed.shape[0] to be ${IMAGE_EMBED_SIZE}, ` +
`InternalError: expect embed.shape[0] to be ${expectedSize}, ` +
`but got ${embed.shape[0]}`,
);
}
Expand Down Expand Up @@ -1519,7 +1535,9 @@ export class LLMChatPipeline {
* imageUrl2,
* token array for "\nSome user input<|end|>\n"
*/
private getInputData(): [Array<Array<number> | ImageURL>, number] {
private async getInputData(): Promise<
[Array<Array<number> | ImageURL>, number, (image: ImageURL) => number]
> {
const ret: Array<Array<number> | ImageURL> = [];
let curTokens: Array<number> = [];
let prompts: Array<string | Array<string | ImageURL>>;
Expand All @@ -1546,6 +1564,32 @@ export class LLMChatPipeline {
}
}

// 1.5. Preload image dimensions to compute per-image embed sizes
const imageDimensions = new Map<string, [number, number]>();
const uniqueImageUrls = new Set<string>();
for (const prompt of prompts) {
if (typeof prompt !== "string") {
for (const content of prompt) {
if (typeof content !== "string") {
uniqueImageUrls.add(content.url);
}
}
}
}
await Promise.all(
Array.from(uniqueImageUrls).map(async (url) => {
const imgData = await getImageDataFromURL(url);
imageDimensions.set(url, [imgData.height, imgData.width]);
}),
);
const getEmbedSize = (image: ImageURL): number => {
const dims = imageDimensions.get(image.url);
if (!dims) {
throw new Error("InternalError: image dimensions not preloaded");
}
return this.computeImageEmbedSize(dims[0], dims[1]);
};

// 2. Encode all prompts. Iterate through each message in the prompt array, where each
// prompt can either be a string, or an array of a mixture of string and ImageURLs.
let numPromptTokens = 0;
Expand All @@ -1563,11 +1607,25 @@ export class LLMChatPipeline {
numPromptTokens += encoded.length;
curTokens.push(...encoded);
} else {
// Insert BOI wrapping if configured (e.g. Gemma3V: \n + BOI token before image)
const boiToken = this.config.model_config?.boi_token_index;
if (boiToken !== undefined) {
const nlTokens = this.tokenizer.encode("\n");
curTokens.push(...nlTokens);
curTokens.push(boiToken);
numPromptTokens += nlTokens.length + 1;
}
// push curTokens to ret, push imageUrl, create a new curTokens
ret.push([...curTokens]);
ret.push(curPromptContent);
numPromptTokens += IMAGE_EMBED_SIZE;
numPromptTokens += getEmbedSize(curPromptContent);
curTokens = [];
// Insert EOI token after image if configured
const eoiToken = this.config.model_config?.eoi_token_index;
if (eoiToken !== undefined) {
curTokens.push(eoiToken);
numPromptTokens += 1;
}
}
}
}
Expand All @@ -1587,7 +1645,7 @@ export class LLMChatPipeline {
this.contextWindowSize,
);
}
return [ret, numPromptTokens];
return [ret, numPromptTokens, getEmbedSize];
}

async forwardTokensAndSample(
Expand All @@ -1601,6 +1659,7 @@ export class LLMChatPipeline {
const retGetChunks = getChunkedPrefillInputData(
inputData,
this.prefillChunkSize,
() => 0, // text-only path, no images
);
const chunks: Array<Array<number> | ImageURL>[] = retGetChunks[0];
const chunkLens: Array<number> = retGetChunks[1];
Expand Down
17 changes: 11 additions & 6 deletions src/support.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
ToolCallOutputMissingFieldsError,
ToolCallOutputParseError,
UnclearModelToUseError,
PrefillChunkSizeSmallerThanImageError,
} from "./error";

/**
Expand Down Expand Up @@ -278,11 +279,12 @@ export function getModelIdToUse(
* Chunk the inputData such that each chunk's total input length is smaller than prefill
* chunk size.
* @returns [the data chunks, the input length of each chunk]
* @note precondition: if inputData has image in it, then prefillChunkSize >= IMAGE_EMBED_SIZE.
* @note precondition: if inputData has image in it, then prefillChunkSize >= imageEmbedSize.
*/
export function getChunkedPrefillInputData(
inputData: Array<Array<number> | ImageURL>,
prefillChunkSize: number,
getImageEmbedSize: (image: ImageURL) => number,
): [Array<Array<number> | ImageURL>[], Array<number>] {
const chunks: Array<Array<number> | ImageURL>[] = [];
const chunkLens: Array<number> = [];
Expand All @@ -292,7 +294,13 @@ export function getChunkedPrefillInputData(
let curData: Array<number> | ImageURL = inputData[i];
const curDataLen = Array.isArray(curData)
? curData.length
: IMAGE_EMBED_SIZE;
: getImageEmbedSize(curData);
if (!Array.isArray(curData) && curDataLen > prefillChunkSize) {
throw new PrefillChunkSizeSmallerThanImageError(
prefillChunkSize,
curDataLen,
);
}
// 1. curData can fit into this chunk
if (curChunkLen + curDataLen <= prefillChunkSize) {
curChunk.push(curData);
Expand Down Expand Up @@ -338,7 +346,7 @@ export function getChunkedPrefillInputData(
chunkLens.push(curChunkLen);
// 2.2.2. Then push image to the new chunk
curChunk = [curData];
curChunkLen = IMAGE_EMBED_SIZE;
curChunkLen = curDataLen;
if (curChunkLen === prefillChunkSize) {
chunks.push([...curChunk]);
chunkLens.push(curChunkLen);
Expand Down Expand Up @@ -405,9 +413,6 @@ export class CustomLock {
// Image related
type ImageURL = ChatCompletionContentPartImage.ImageURL;

// TODO(Charlie): currently hardcoded for phi3.5-vision num_crops 16
export const IMAGE_EMBED_SIZE = 1921;

/**
* Given a url, get the image data. The url can either start with `http` or `data:image`.
*/
Expand Down
84 changes: 77 additions & 7 deletions tests/llm_chat_pipeline.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,9 @@ function preparePrefillPipeline(): PipelineLike {
const pipeline = createPipeline();
pipeline["prefillTotalTime"] = 0;
pipeline["prefillTotalTokens"] = 0;
pipeline["getInputData"] = jest.fn<() => [number[][], number]>(() => [
[[0]],
1,
]);
pipeline["getInputData"] = jest.fn(
async (): Promise<[any[], number, any]> => [[[0]], 1, () => 0],
);
pipeline["processNextToken"] = jest.fn();
return pipeline;
}
Expand Down Expand Up @@ -272,15 +271,15 @@ test("prefillStep compiles custom grammar when response type is grammar", async
expect(compileGrammarMock).toHaveBeenCalledWith("root ::= WORD");
});

test("getInputData uses cached prompts when KV cache filled", () => {
test("getInputData uses cached prompts when KV cache filled", async () => {
  // Arrange: stub out tokenization and disable system prefix tokens.
  const pipe = createPipeline();
  pipe["tokenizer"].encode = jest.fn(() => Int32Array.from([1]));
  pipe["conversation"].config.system_prefix_token_ids = undefined;

  // An empty KV cache must request the full prompt array...
  pipe["filledKVCacheLength"] = 0;
  await (pipe as any).getInputData();
  expect(pipe["conversation"].getPromptArray).toHaveBeenCalled();

  // ...while a partially filled cache only fetches the last round.
  pipe["filledKVCacheLength"] = 1;
  await (pipe as any).getInputData();
  expect(pipe["conversation"].getPromptArrayLastRound).toHaveBeenCalled();
});

Expand All @@ -292,3 +291,74 @@ test("processNextToken ignores eos when requested", () => {
expect(pipeline["finishReason"]).toBeUndefined();
expect(pipeline["outputIds"]).toContain(1);
});

describe("calculateResizeShape", () => {
  /** Invoke the private Phi3-V resize helper on a fresh pipeline. */
  const resize = (height: number, width: number) =>
    createPipeline()["calculateResizeShape"](height, width);

  test("square image", () => {
    expect(resize(336, 336)).toEqual([1344, 1344]);
  });

  test("landscape image", () => {
    expect(resize(1080, 1920)).toEqual([945, 1680]);
  });

  test("portrait image", () => {
    expect(resize(1920, 1080)).toEqual([1194, 672]);
  });
});

describe("calculateCropShape", () => {
  /** Invoke the private Phi3-V crop-grid helper on a fresh pipeline. */
  const crop = (height: number, width: number) =>
    createPipeline()["calculateCropShape"](height, width);

  test("square image", () => {
    expect(crop(336, 336)).toEqual([4, 4]);
  });

  test("landscape image", () => {
    expect(crop(1080, 1920)).toEqual([3, 5]);
  });

  test("portrait image", () => {
    expect(crop(1920, 1080)).toEqual([4, 2]);
  });
});

describe("computeImageEmbedSize", () => {
  /** Build a pipeline whose chat config is replaced with the given fields. */
  const pipelineWithConfig = (config: Record<string, unknown>) => {
    const pipeline = createPipeline();
    pipeline["config"] = config as any;
    return pipeline;
  };

  test("phi3_v square image", () => {
    const pipeline = pipelineWithConfig({ model_type: "phi3_v" });
    expect(pipeline["computeImageEmbedSize"](336, 336)).toBe(2509);
  });

  test("phi3_v landscape image", () => {
    const pipeline = pipelineWithConfig({ model_type: "phi3_v" });
    expect(pipeline["computeImageEmbedSize"](1080, 1920)).toBe(2353);
  });

  test("phi3_v portrait image", () => {
    const pipeline = pipelineWithConfig({ model_type: "phi3_v" });
    expect(pipeline["computeImageEmbedSize"](1920, 1080)).toBe(1357);
  });

  test("model with mm_tokens_per_image", () => {
    const pipeline = pipelineWithConfig({
      model_type: "gemma3_v",
      model_config: { mm_tokens_per_image: 256 },
    });
    expect(pipeline["computeImageEmbedSize"](1080, 1920)).toBe(256);
  });

  test("unknown model without mm_tokens throws", () => {
    const pipeline = pipelineWithConfig({ model_type: "unknown_model" });
    expect(() => pipeline["computeImageEmbedSize"](336, 336)).toThrow(
      "Cannot determine image embed size",
    );
  });
});
Loading