Skip to content

Commit 10fb579

Browse files
authored
Allow custom type to be streamed and use native response (#8778)
* Allow custom dspy type to be streamed * nit * Add custom type parsing logic * simplify * update type * update type * add lm check * comment * move native response config into adapter * fix test * comment
1 parent 28a7f78 commit 10fb579

File tree

6 files changed

+182
-44
lines changed

6 files changed

+182
-44
lines changed

dspy/adapters/base.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import json_repair
55
import litellm
66

7-
from dspy.adapters.types import History
7+
from dspy.adapters.types import History, Type
88
from dspy.adapters.types.base_type import split_message_content_for_custom_types
99
from dspy.adapters.types.tool import Tool, ToolCalls
1010
from dspy.experimental import Citations
@@ -16,11 +16,13 @@
1616
if TYPE_CHECKING:
1717
from dspy.clients.lm import LM
1818

19+
_DEFAULT_NATIVE_RESPONSE_TYPES = [Citations]
1920

2021
class Adapter:
21-
def __init__(self, callbacks: list[BaseCallback] | None = None, use_native_function_calling: bool = False):
22+
def __init__(self, callbacks: list[BaseCallback] | None = None, use_native_function_calling: bool = False, native_response_types: list[type[Type]] | None = None):
2223
self.callbacks = callbacks or []
2324
self.use_native_function_calling = use_native_function_calling
25+
self.native_response_types = native_response_types or _DEFAULT_NATIVE_RESPONSE_TYPES
2426

2527
def __init_subclass__(cls, **kwargs) -> None:
2628
super().__init_subclass__(**kwargs)
@@ -64,9 +66,10 @@ def _call_preprocess(
6466

6567
return signature_for_native_function_calling
6668

67-
citation_output_field_name = self._get_citation_output_field_name(signature)
68-
if citation_output_field_name:
69-
signature = signature.delete(citation_output_field_name)
69+
# Handle custom types that use native response
70+
for name, field in signature.output_fields.items():
71+
if isinstance(field.annotation, type) and issubclass(field.annotation, Type) and field.annotation in self.native_response_types:
72+
signature = signature.delete(name)
7073

7174
return signature
7275

@@ -75,23 +78,21 @@ def _call_postprocess(
7578
processed_signature: type[Signature],
7679
original_signature: type[Signature],
7780
outputs: list[dict[str, Any]],
81+
lm: "LM",
7882
) -> list[dict[str, Any]]:
7983
values = []
8084

8185
tool_call_output_field_name = self._get_tool_call_output_field_name(original_signature)
82-
citation_output_field_name = self._get_citation_output_field_name(original_signature)
8386

8487
for output in outputs:
8588
output_logprobs = None
8689
tool_calls = None
87-
citations = None
8890
text = output
8991

9092
if isinstance(output, dict):
9193
text = output["text"]
9294
output_logprobs = output.get("logprobs")
9395
tool_calls = output.get("tool_calls")
94-
citations = output.get("citations")
9596

9697
if text:
9798
value = self.parse(processed_signature, text)
@@ -114,9 +115,10 @@ def _call_postprocess(
114115
]
115116
value[tool_call_output_field_name] = ToolCalls.from_dict_list(tool_calls)
116117

117-
if citations and citation_output_field_name:
118-
citations_obj = Citations.from_dict_list(citations)
119-
value[citation_output_field_name] = citations_obj
118+
# Parse custom types that do not rely on the adapter parsing
119+
for name, field in original_signature.output_fields.items():
120+
if isinstance(field.annotation, type) and issubclass(field.annotation, Type) and field.annotation in self.native_response_types:
121+
value[name] = field.annotation.parse_lm_response(output)
120122

121123
if output_logprobs:
122124
value["logprobs"] = output_logprobs
@@ -137,7 +139,7 @@ def __call__(
137139
inputs = self.format(processed_signature, demos, inputs)
138140

139141
outputs = lm(messages=inputs, **lm_kwargs)
140-
return self._call_postprocess(processed_signature, signature, outputs)
142+
return self._call_postprocess(processed_signature, signature, outputs, lm)
141143

142144
async def acall(
143145
self,
@@ -151,7 +153,7 @@ async def acall(
151153
inputs = self.format(processed_signature, demos, inputs)
152154

153155
outputs = await lm.acall(messages=inputs, **lm_kwargs)
154-
return self._call_postprocess(processed_signature, signature, outputs)
156+
return self._call_postprocess(processed_signature, signature, outputs, lm)
155157

156158
def format(
157159
self,
@@ -402,12 +404,6 @@ def _get_tool_call_output_field_name(self, signature: type[Signature]) -> bool:
402404
return name
403405
return None
404406

405-
def _get_citation_output_field_name(self, signature: type[Signature]) -> str | None:
406-
"""Find the Citations output field in the signature."""
407-
for name, field in signature.output_fields.items():
408-
if field.annotation == Citations:
409-
return name
410-
return None
411407

412408
def format_conversation_history(
413409
self,

dspy/adapters/types/base_type.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import json
22
import re
3-
from typing import Any, get_args, get_origin
3+
from typing import Any, Optional, get_args, get_origin
44

55
import json_repair
66
import pydantic
7+
from litellm import ModelResponseStream
78

89
CUSTOM_TYPE_START_IDENTIFIER = "<<CUSTOM-TYPE-START-IDENTIFIER>>"
910
CUSTOM_TYPE_END_IDENTIFIER = "<<CUSTOM-TYPE-END-IDENTIFIER>>"
@@ -69,6 +70,36 @@ def serialize_model(self):
6970
)
7071
return formatted
7172

73+
@classmethod
74+
def is_streamable(cls) -> bool:
75+
"""Whether the custom type is streamable."""
76+
return False
77+
78+
@classmethod
79+
def parse_stream_chunk(cls, chunk: ModelResponseStream) -> Optional["Type"]:
80+
"""
81+
Parse a stream chunk into the custom type.
82+
83+
Args:
84+
chunk: A stream chunk.
85+
86+
Returns:
87+
A custom type object or None if the chunk is not for this custom type.
88+
"""
89+
return None
90+
91+
92+
@classmethod
93+
def parse_lm_response(cls, response: str | dict[str, Any]) -> Optional["Type"]:
94+
"""Parse a LM response into the custom type.
95+
96+
Args:
97+
response: An LM response.
98+
99+
Returns:
100+
A custom type object.
101+
"""
102+
return None
72103

73104
def split_message_content_for_custom_types(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
74105
"""Split user message content into a list of content blocks.

dspy/adapters/types/citation.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any
1+
from typing import Any, Optional
22

33
import pydantic
44

@@ -166,3 +166,51 @@ def __len__(self):
166166
def __getitem__(self, index):
167167
"""Allow indexing into citations."""
168168
return self.citations[index]
169+
170+
@classmethod
171+
def is_streamable(cls) -> bool:
172+
"""Whether the Citations type is streamable."""
173+
return True
174+
175+
@classmethod
176+
def parse_stream_chunk(cls, chunk) -> Optional["Citations"]:
177+
"""
178+
Parse a stream chunk into Citations.
179+
180+
Args:
181+
chunk: A stream chunk from the LM.
182+
183+
Returns:
184+
A Citations object if the chunk contains citation data, None otherwise.
185+
"""
186+
try:
187+
# Check if the chunk has citation data in provider_specific_fields
188+
if hasattr(chunk, "choices") and chunk.choices:
189+
delta = chunk.choices[0].delta
190+
if hasattr(delta, "provider_specific_fields") and delta.provider_specific_fields:
191+
citation_data = delta.provider_specific_fields.get("citation")
192+
if citation_data:
193+
return cls.from_dict_list([citation_data])
194+
except Exception:
195+
pass
196+
return None
197+
198+
199+
@classmethod
200+
def parse_lm_response(cls, response: str | dict[str, Any]) -> Optional["Citations"]:
201+
"""Parse a LM response into Citations.
202+
203+
Args:
204+
response: An LM response that may contain citation data.
205+
206+
Returns:
207+
A Citations object if citation data is found, None otherwise.
208+
"""
209+
if isinstance(response, dict):
210+
# Check if the response contains citations in the expected format
211+
if "citations" in response:
212+
citations_data = response["citations"]
213+
if isinstance(citations_data, list):
214+
return cls.from_dict_list(citations_data)
215+
216+
return None

dspy/streaming/streaming_listener.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from dspy.adapters.chat_adapter import ChatAdapter
99
from dspy.adapters.json_adapter import JSONAdapter
10-
from dspy.adapters.types.citation import Citations
10+
from dspy.adapters.types import Type
1111
from dspy.adapters.xml_adapter import XMLAdapter
1212
from dspy.dsp.utils.settings import settings
1313
from dspy.streaming.messages import StreamResponse
@@ -102,18 +102,15 @@ def receive(self, chunk: ModelResponseStream):
102102
except Exception:
103103
return
104104

105-
# Handle anthropic citations. see https://docs.litellm.ai/docs/providers/anthropic#beta-citations-api
106-
try:
107-
if self._signature_field_is_citation_type():
108-
if chunk_citation := chunk.choices[0].delta.provider_specific_fields.get("citation", None):
109-
return StreamResponse(
110-
self.predict_name,
111-
self.signature_field_name,
112-
Citations.from_dict_list([chunk_citation]),
113-
is_last_chunk=False,
114-
)
115-
except Exception:
116-
pass
105+
# Handle custom streamable types
106+
if self._output_type and issubclass(self._output_type, Type) and self._output_type.is_streamable():
107+
if parsed_chunk := self._output_type.parse_stream_chunk(chunk):
108+
return StreamResponse(
109+
self.predict_name,
110+
self.signature_field_name,
111+
parsed_chunk,
112+
is_last_chunk=self.stream_end,
113+
)
117114

118115
if chunk_message and start_identifier in chunk_message:
119116
# If the cache is hit, the chunk_message could be the full response. When it happens we can
@@ -217,10 +214,13 @@ def flush(self) -> str:
217214
f"{', '.join([a.__name__ for a in ADAPTER_SUPPORT_STREAMING])}"
218215
)
219216

220-
def _signature_field_is_citation_type(self) -> bool:
221-
"""Check if the signature field is a citations field."""
222-
from dspy.predict import Predict
223-
return isinstance(self.predict, Predict) and getattr(self.predict.signature.output_fields.get(self.signature_field_name, None), "annotation", None) == Citations
217+
@property
218+
def _output_type(self) -> type | None:
219+
try:
220+
return self.predict.signature.output_fields[self.signature_field_name].annotation
221+
except Exception:
222+
return None
223+
224224

225225

226226
def find_predictor_for_stream_listeners(program: "Module", stream_listeners: list[StreamListener]):
@@ -249,10 +249,10 @@ def find_predictor_for_stream_listeners(program: "Module", stream_listeners: lis
249249
"predictor to use for streaming. Please specify the predictor to listen to."
250250
)
251251

252-
if field_info.annotation not in [str, Citations]:
252+
if not _is_streamable(field_info.annotation):
253253
raise ValueError(
254-
f"Stream listener can only be applied to string or Citations output field, but your field {field_name} is of "
255-
f"type {field_info.annotation}."
254+
f"Stream listener can only be applied to string or subclass of `dspy.Type` that has `is_streamable() == True`, "
255+
f"but your field {field_name} is of type {field_info.annotation}."
256256
)
257257

258258
field_name_to_named_predictor[field_name] = (name, predictor)
@@ -271,3 +271,12 @@ def find_predictor_for_stream_listeners(program: "Module", stream_listeners: lis
271271
listener.predict_name, listener.predict = field_name_to_named_predictor[listener.signature_field_name]
272272
predict_id_to_listener[id(listener.predict)].append(listener)
273273
return predict_id_to_listener
274+
275+
def _is_streamable(field_type: type | None) -> bool:
276+
if field_type is None:
277+
return False
278+
if field_type is str:
279+
return True
280+
if issubclass(field_type, Type):
281+
return field_type.is_streamable()
282+
return False

tests/adapters/test_citation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class CitationSignature(Signature):
135135
answer: str = dspy.OutputField()
136136
citations: Citations = dspy.OutputField()
137137

138-
adapter = ChatAdapter()
138+
adapter = ChatAdapter(native_response_types=[Citations])
139139

140140
outputs = [{
141141
"text": "[[ ## answer ## ]]\nThe answer is blue.\n\n[[ ## citations ## ]]\n[]",
@@ -154,7 +154,8 @@ class CitationSignature(Signature):
154154
result = adapter._call_postprocess(
155155
CitationSignature.delete("citations"),
156156
CitationSignature,
157-
outputs
157+
outputs,
158+
dspy.LM(model="claude-3-5-sonnet-20241022")
158159
)
159160

160161
assert len(result) == 1

tests/streaming/test_streaming.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices
1010

1111
import dspy
12+
from dspy.adapters.types import Type
1213
from dspy.experimental import Citations, Document
1314
from dspy.streaming import StatusMessage, StatusMessageProvider, streaming_response
1415

@@ -877,6 +878,58 @@ async def send_to_stream():
877878
assert isinstance(all_chunks[1], dspy.Prediction)
878879

879880

881+
@pytest.mark.anyio
882+
async def test_streaming_allows_custom_streamable_type():
883+
class CustomType(Type):
884+
message: str
885+
886+
@classmethod
887+
def is_streamable(cls) -> bool:
888+
return True
889+
890+
@classmethod
891+
def parse_stream_chunk(cls, chunk):
892+
return CustomType(message=chunk.choices[0].delta.content)
893+
894+
@classmethod
895+
def parse_lm_response(cls, response: dict) -> "CustomType":
896+
return CustomType(message=response.split("\n\n")[0])
897+
898+
class CustomSignature(dspy.Signature):
899+
question: str = dspy.InputField()
900+
answer: CustomType = dspy.OutputField()
901+
902+
program = dspy.streamify(
903+
dspy.Predict(CustomSignature),
904+
stream_listeners=[
905+
dspy.streaming.StreamListener(signature_field_name="answer"),
906+
],
907+
)
908+
909+
async def stream(*args, **kwargs):
910+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="Hello"))])
911+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="World"))])
912+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="\n\n"))])
913+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[[ ##"))])
914+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" completed"))])
915+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
916+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ]]"))])
917+
918+
919+
with mock.patch("litellm.acompletion", side_effect=stream):
920+
with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.ChatAdapter(native_response_types=[CustomType])):
921+
output = program(question="why did a chicken cross the kitchen?")
922+
all_chunks = []
923+
async for value in output:
924+
if isinstance(value, dspy.streaming.StreamResponse):
925+
all_chunks.append(value)
926+
elif isinstance(value, dspy.Prediction):
927+
assert isinstance(value.answer, CustomType)
928+
assert value.answer.message == "HelloWorld"
929+
930+
assert all(isinstance(chunk.chunk, CustomType) for chunk in all_chunks)
931+
932+
880933
@pytest.mark.anyio
881934
async def test_streaming_with_citations():
882935
class AnswerWithSources(dspy.Signature):
@@ -936,7 +989,7 @@ async def citation_stream(*args, **kwargs):
936989
# Create test documents
937990
docs = [Document(data="Water boils at 100°C at standard pressure.", title="Physics Facts")]
938991

939-
with dspy.context(lm=dspy.LM("anthropic/claude-3-5-sonnet-20241022", cache=False)):
992+
with dspy.context(lm=dspy.LM("anthropic/claude-3-5-sonnet-20241022", cache=False), adapter=dspy.ChatAdapter(native_response_types=[Citations])):
940993
output = program(documents=docs, question="What temperature does water boil?")
941994
citation_chunks = []
942995
final_prediction = None

0 commit comments

Comments
 (0)