-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Allow custom type to be streamed and use native response #8778
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d878a3d
9367db3
a1a8139
f4248e6
c3358b4
bd6976c
2b8ed13
60cc30c
5490d70
b48045e
4ad0358
f2fbb22
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -7,7 +7,7 @@ | |||||||||||||
|
||||||||||||||
from dspy.adapters.chat_adapter import ChatAdapter | ||||||||||||||
from dspy.adapters.json_adapter import JSONAdapter | ||||||||||||||
from dspy.adapters.types.citation import Citations | ||||||||||||||
from dspy.adapters.types import Type | ||||||||||||||
from dspy.adapters.xml_adapter import XMLAdapter | ||||||||||||||
from dspy.dsp.utils.settings import settings | ||||||||||||||
from dspy.streaming.messages import StreamResponse | ||||||||||||||
|
@@ -102,18 +102,15 @@ def receive(self, chunk: ModelResponseStream): | |||||||||||||
except Exception: | ||||||||||||||
return | ||||||||||||||
|
||||||||||||||
# Handle anthropic citations. see https://docs.litellm.ai/docs/providers/anthropic#beta-citations-api | ||||||||||||||
try: | ||||||||||||||
if self._signature_field_is_citation_type(): | ||||||||||||||
if chunk_citation := chunk.choices[0].delta.provider_specific_fields.get("citation", None): | ||||||||||||||
return StreamResponse( | ||||||||||||||
self.predict_name, | ||||||||||||||
self.signature_field_name, | ||||||||||||||
Citations.from_dict_list([chunk_citation]), | ||||||||||||||
is_last_chunk=False, | ||||||||||||||
) | ||||||||||||||
except Exception: | ||||||||||||||
pass | ||||||||||||||
# Handle custom streamable types | ||||||||||||||
if self._output_type and issubclass(self._output_type, Type) and self._output_type.is_streamable(): | ||||||||||||||
if parsed_chunk := self._output_type.parse_stream_chunk(chunk): | ||||||||||||||
return StreamResponse( | ||||||||||||||
self.predict_name, | ||||||||||||||
self.signature_field_name, | ||||||||||||||
parsed_chunk, | ||||||||||||||
is_last_chunk=self.stream_end, | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
if chunk_message and start_identifier in chunk_message: | ||||||||||||||
# If the cache is hit, the chunk_message could be the full response. When it happens we can | ||||||||||||||
|
@@ -217,10 +214,13 @@ def flush(self) -> str: | |||||||||||||
f"{', '.join([a.__name__ for a in ADAPTER_SUPPORT_STREAMING])}" | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
def _signature_field_is_citation_type(self) -> bool: | ||||||||||||||
"""Check if the signature field is a citations field.""" | ||||||||||||||
from dspy.predict import Predict | ||||||||||||||
return isinstance(self.predict, Predict) and getattr(self.predict.signature.output_fields.get(self.signature_field_name, None), "annotation", None) == Citations | ||||||||||||||
@property | ||||||||||||||
def _output_type(self) -> type | None: | ||||||||||||||
try: | ||||||||||||||
return self.predict.signature.output_fields[self.signature_field_name].annotation | ||||||||||||||
except Exception: | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. when will this throw an exception? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It does not fail in normal cases, but since There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ic, actually I think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know, but typing-wise they are nullable ( dspy/dspy/streaming/streaming_listener.py Lines 24 to 29 in 28a7f78
|
||||||||||||||
return None | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
|
||||||||||||||
def find_predictor_for_stream_listeners(program: "Module", stream_listeners: list[StreamListener]): | ||||||||||||||
|
@@ -249,10 +249,10 @@ def find_predictor_for_stream_listeners(program: "Module", stream_listeners: lis | |||||||||||||
"predictor to use for streaming. Please specify the predictor to listen to." | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
if field_info.annotation not in [str, Citations]: | ||||||||||||||
if not _is_streamable(field_info.annotation): | ||||||||||||||
raise ValueError( | ||||||||||||||
f"Stream listener can only be applied to string or Citations output field, but your field {field_name} is of " | ||||||||||||||
f"type {field_info.annotation}." | ||||||||||||||
f"Stream listener can only be applied to string or subclass of `dspy.Type` that has `is_streamable() == True`, " | ||||||||||||||
f"but your field {field_name} is of type {field_info.annotation}." | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
field_name_to_named_predictor[field_name] = (name, predictor) | ||||||||||||||
|
@@ -271,3 +271,12 @@ def find_predictor_for_stream_listeners(program: "Module", stream_listeners: lis | |||||||||||||
listener.predict_name, listener.predict = field_name_to_named_predictor[listener.signature_field_name] | ||||||||||||||
predict_id_to_listener[id(listener.predict)].append(listener) | ||||||||||||||
return predict_id_to_listener | ||||||||||||||
|
||||||||||||||
def _is_streamable(field_type: type | None) -> bool: | ||||||||||||||
if field_type is None: | ||||||||||||||
return False | ||||||||||||||
if field_type is str: | ||||||||||||||
return True | ||||||||||||||
if issubclass(field_type, Type): | ||||||||||||||
return field_type.is_streamable() | ||||||||||||||
return False |
Uh oh!
There was an error while loading. Please reload this page.