
Commit 626605b

streaming: support customizing model name and max tokens
For fine-tuned models, we need a different model name, and the default max-token calculation won't work for them.
1 parent: 8683516

1 file changed: +11 -5 lines


streaming.py

Lines changed: 11 additions & 5 deletions
@@ -30,22 +30,28 @@ def streaming_callback_manager(new_token_handler: Callable) -> CallbackManager:
     return CallbackManager([StreamingCallbackHandler(new_token_handler)])
 
 
-def get_streaming_llm(new_token_handler):
+def get_streaming_llm(new_token_handler, model_name=None, max_tokens=-1):
     # falls back to non-streaming if none provided
     streaming_kwargs = dict(
         streaming=True,
         callback_manager=streaming_callback_manager(new_token_handler),
     ) if new_token_handler else {}
 
+    model_kwargs = dict(
+        model_name=model_name,
+    ) if model_name else {}
+
     llm = OpenAI(
-        temperature=0.0, max_tokens=-1,
-        **streaming_kwargs
+        temperature=0.0,
+        max_tokens=max_tokens,
+        **streaming_kwargs,
+        **model_kwargs,
     )
     return llm
 
 
-def get_streaming_chain(prompt, new_token_handler, use_api_chain=False):
-    llm = get_streaming_llm(new_token_handler)
+def get_streaming_chain(prompt, new_token_handler, use_api_chain=False, model_name=None, max_tokens=-1):
+    llm = get_streaming_llm(new_token_handler, model_name=model_name, max_tokens=max_tokens)
 
     if use_api_chain:
         return IndexAPIChain.from_llm(

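For illustration, a minimal usage sketch of the new parameters (not part of the commit). It assumes the module above is importable as streaming; the token handler, the prompt, the fine-tune model ID, and the max_tokens value are all placeholders:

    # Hypothetical caller code; the fine-tune ID and prompt are placeholders.
    from langchain import PromptTemplate
    from streaming import get_streaming_chain, get_streaming_llm

    def print_token(token: str) -> None:
        # Stream each token to stdout as it arrives.
        print(token, end="", flush=True)

    # Default behavior is unchanged: stock model name, max_tokens=-1.
    llm = get_streaming_llm(print_token)

    # Fine-tuned model: name it explicitly and choose a concrete max_tokens,
    # since the default max-token calculation doesn't know this model.
    prompt = PromptTemplate(input_variables=["question"], template="Q: {question}\nA:")
    chain = get_streaming_chain(
        prompt,
        print_token,
        model_name="davinci:ft-acme-2023-05-01-12-00-00",  # placeholder fine-tune ID
        max_tokens=512,
    )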