Skip to content

Commit 9fa1b27

Browse files
authored
LlmModule prefill refactor (#14100)
No longer expose the prefill API; instead, accept multiple inputs, to be consistent with the multimodal runner API. Also add a new API to force-reset the context.
1 parent a4b0822 commit 9fa1b27

File tree

2 files changed

+50
-36
lines changed

2 files changed

+50
-36
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -173,20 +173,23 @@ public native int generate(
173173
* @param height Input image height
174174
* @param channels Input image number of channels
175175
* @param startPos The starting position in KV cache of the input in the LLM.
176-
* @return The updated starting position in KV cache of the input in the LLM.
176+
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
177+
* exposed to the user.
177178
* @throws RuntimeException if the prefill failed
178179
*/
180+
@Deprecated
179181
public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
180-
long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos);
181-
if (nativeResult[0] != 0) {
182-
throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
182+
if (startPos == 0) {
183+
resetContext();
183184
}
184-
return nativeResult[1];
185+
int nativeResult = appendImagesInput(image, width, height, channels);
186+
if (nativeResult != 0) {
187+
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
188+
}
189+
return 0;
185190
}
186191

187-
// returns a tuple of (status, updated startPos)
188-
private native long[] prefillImagesNative(
189-
int[] image, int width, int height, int channels, long startPos);
192+
private native int appendImagesInput(int[] image, int width, int height, int channels);
190193

191194
/**
192195
* Prefill an LLaVA Module with the given text input.
@@ -196,33 +199,48 @@ private native long[] prefillImagesNative(
196199
* reference and will be updated inside this function.
197200
* @param bos The number of BOS (begin of sequence) token.
198201
* @param eos The number of EOS (end of sequence) token.
199-
* @return The updated starting position in KV cache of the input in the LLM.
202+
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
203+
* exposed to the user.
200204
* @throws RuntimeException if the prefill failed
201205
*/
206+
@Deprecated
202207
public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
203-
long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos);
204-
if (nativeResult[0] != 0) {
205-
throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
208+
if (startPos == 0) {
209+
resetContext();
206210
}
207-
return nativeResult[1];
211+
int nativeResult = appendTextInput(prompt, bos, eos);
212+
if (nativeResult != 0) {
213+
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
214+
}
215+
return 0;
208216
}
209217

210218
// returns the status code only (0 on success); the (status, updated startPos) tuple is gone
211-
private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos);
219+
private native int appendTextInput(String prompt, int bos, int eos);
212220

213221
/**
214222
* Generate tokens from the given prompt, starting from the given position.
215223
*
224+
* <p>This is a deprecated API. Please use {@link #generate(String, int, LlmCallback, boolean)}
225+
*
216226
* @param prompt The text prompt to LLaVA.
217227
* @param seqLen The total sequence length, including the prompt tokens and new tokens.
218228
* @param startPos The starting position in KV cache of the input in the LLM.
219229
* @param callback callback object to receive results.
220230
* @param echo indicate whether to echo the input prompt or not.
221231
* @return The error code.
222232
*/
233+
@Deprecated
223234
public native int generateFromPos(
224235
String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo);
225236

237+
/**
238+
* Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
239+
*
240+
* <p>The startPos will be reset to 0.
241+
*/
242+
public native void resetContext();
243+
226244
/** Stop current generate() before it finishes. */
227245
@DoNotStrip
228246
public native void stop();

extension/android/jni/jni_layer_llama.cpp

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -260,28 +260,19 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
260260
// Returns the status code only (0 == Error::Ok).
261261
// Contract is valid within an AAR (JNI + corresponding Java code)
262262
// The input is merely queued as a prefill input here; start_pos is no longer reported.
263-
facebook::jni::local_ref<jlongArray> prefill_prompt(
263+
jint append_text_input(
264264
facebook::jni::alias_ref<jstring> prompt,
265-
jlong start_pos,
266265
jint bos,
267266
jint eos) {
268267
prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
269-
facebook::jni::local_ref<jlongArray> tuple_result =
270-
facebook::jni::make_long_array(2);
271-
tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
272-
return tuple_result;
268+
return 0;
273269
}
274270

275-
// Returns a tuple of (error, start_pos)
276-
// Contract is valid within an AAR (JNI + corresponding Java code)
277-
// If the first element is not Error::Ok, the other element is undefined.
278-
279-
facebook::jni::local_ref<jlongArray> prefill_images(
271+
jint append_images_input(
280272
facebook::jni::alias_ref<jintArray> image,
281273
jint width,
282274
jint height,
283-
jint channels,
284-
jlong start_pos) {
275+
jint channels) {
285276
std::vector<llm::Image> images;
286277
auto image_size = image->size();
287278
if (image_size != 0) {
@@ -296,11 +287,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
296287
llm::MultimodalInput{std::move(image_runner)});
297288
}
298289

299-
facebook::jni::local_ref<jlongArray> tuple_result =
300-
facebook::jni::make_long_array(2);
301-
302-
tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
303-
return tuple_result;
290+
return 0;
304291
}
305292

306293
jint generate_from_pos(
@@ -325,9 +312,8 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
325312
.seq_len = seq_len,
326313
.temperature = temperature_,
327314
};
328-
return static_cast<jint>(runner_->generate_from_pos(
315+
return static_cast<jint>(runner_->generate(
329316
prompt->toStdString(),
330-
start_pos,
331317
config,
332318
[callback](std::string result) { callback->onResult(result); },
333319
[callback](const llm::Stats& stats) { callback->onStats(stats); }));
@@ -343,6 +329,15 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
343329
}
344330
}
345331

332+
void reset_context() {
333+
if (runner_ != nullptr) {
334+
runner_->reset();
335+
}
336+
if (multi_modal_runner_ != nullptr) {
337+
multi_modal_runner_->reset();
338+
}
339+
}
340+
346341
jint load() {
347342
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
348343
return static_cast<jint>(multi_modal_runner_->load());
@@ -359,11 +354,12 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
359354
makeNativeMethod("stop", ExecuTorchLlmJni::stop),
360355
makeNativeMethod("load", ExecuTorchLlmJni::load),
361356
makeNativeMethod(
362-
"prefillImagesNative", ExecuTorchLlmJni::prefill_images),
357+
"appendImagesInput", ExecuTorchLlmJni::append_images_input),
363358
makeNativeMethod(
364-
"prefillPromptNative", ExecuTorchLlmJni::prefill_prompt),
359+
"appendTextInput", ExecuTorchLlmJni::append_text_input),
365360
makeNativeMethod(
366361
"generateFromPos", ExecuTorchLlmJni::generate_from_pos),
362+
makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),
367363
});
368364
}
369365
};

0 commit comments

Comments
 (0)