4
4
import json_repair
5
5
import litellm
6
6
7
- from dspy .adapters .types import History
7
+ from dspy .adapters .types import History , Type
8
8
from dspy .adapters .types .base_type import split_message_content_for_custom_types
9
9
from dspy .adapters .types .tool import Tool , ToolCalls
10
10
from dspy .experimental import Citations
16
16
if TYPE_CHECKING :
17
17
from dspy .clients .lm import LM
18
18
19
+ _DEFAULT_NATIVE_RESPONSE_TYPES = [Citations ]
19
20
20
21
class Adapter :
21
- def __init__ (self , callbacks : list [BaseCallback ] | None = None , use_native_function_calling : bool = False ):
22
+ def __init__ (self , callbacks : list [BaseCallback ] | None = None , use_native_function_calling : bool = False , native_response_types : list [ type [ Type ]] | None = None ):
22
23
self .callbacks = callbacks or []
23
24
self .use_native_function_calling = use_native_function_calling
25
+ self .native_response_types = native_response_types or _DEFAULT_NATIVE_RESPONSE_TYPES
24
26
25
27
def __init_subclass__ (cls , ** kwargs ) -> None :
26
28
super ().__init_subclass__ (** kwargs )
@@ -64,9 +66,10 @@ def _call_preprocess(
64
66
65
67
return signature_for_native_function_calling
66
68
67
- citation_output_field_name = self ._get_citation_output_field_name (signature )
68
- if citation_output_field_name :
69
- signature = signature .delete (citation_output_field_name )
69
+ # Handle custom types that use native response
70
+ for name , field in signature .output_fields .items ():
71
+ if isinstance (field .annotation , type ) and issubclass (field .annotation , Type ) and field .annotation in self .native_response_types :
72
+ signature = signature .delete (name )
70
73
71
74
return signature
72
75
@@ -75,23 +78,21 @@ def _call_postprocess(
75
78
processed_signature : type [Signature ],
76
79
original_signature : type [Signature ],
77
80
outputs : list [dict [str , Any ]],
81
+ lm : "LM" ,
78
82
) -> list [dict [str , Any ]]:
79
83
values = []
80
84
81
85
tool_call_output_field_name = self ._get_tool_call_output_field_name (original_signature )
82
- citation_output_field_name = self ._get_citation_output_field_name (original_signature )
83
86
84
87
for output in outputs :
85
88
output_logprobs = None
86
89
tool_calls = None
87
- citations = None
88
90
text = output
89
91
90
92
if isinstance (output , dict ):
91
93
text = output ["text" ]
92
94
output_logprobs = output .get ("logprobs" )
93
95
tool_calls = output .get ("tool_calls" )
94
- citations = output .get ("citations" )
95
96
96
97
if text :
97
98
value = self .parse (processed_signature , text )
@@ -114,9 +115,10 @@ def _call_postprocess(
114
115
]
115
116
value [tool_call_output_field_name ] = ToolCalls .from_dict_list (tool_calls )
116
117
117
- if citations and citation_output_field_name :
118
- citations_obj = Citations .from_dict_list (citations )
119
- value [citation_output_field_name ] = citations_obj
118
+ # Parse custom types that does not rely on the adapter parsing
119
+ for name , field in original_signature .output_fields .items ():
120
+ if isinstance (field .annotation , type ) and issubclass (field .annotation , Type ) and field .annotation in self .native_response_types :
121
+ value [name ] = field .annotation .parse_lm_response (output )
120
122
121
123
if output_logprobs :
122
124
value ["logprobs" ] = output_logprobs
@@ -137,7 +139,7 @@ def __call__(
137
139
inputs = self .format (processed_signature , demos , inputs )
138
140
139
141
outputs = lm (messages = inputs , ** lm_kwargs )
140
- return self ._call_postprocess (processed_signature , signature , outputs )
142
+ return self ._call_postprocess (processed_signature , signature , outputs , lm )
141
143
142
144
async def acall (
143
145
self ,
@@ -151,7 +153,7 @@ async def acall(
151
153
inputs = self .format (processed_signature , demos , inputs )
152
154
153
155
outputs = await lm .acall (messages = inputs , ** lm_kwargs )
154
- return self ._call_postprocess (processed_signature , signature , outputs )
156
+ return self ._call_postprocess (processed_signature , signature , outputs , lm )
155
157
156
158
def format (
157
159
self ,
@@ -402,12 +404,6 @@ def _get_tool_call_output_field_name(self, signature: type[Signature]) -> bool:
402
404
return name
403
405
return None
404
406
405
- def _get_citation_output_field_name (self , signature : type [Signature ]) -> str | None :
406
- """Find the Citations output field in the signature."""
407
- for name , field in signature .output_fields .items ():
408
- if field .annotation == Citations :
409
- return name
410
- return None
411
407
412
408
def format_conversation_history (
413
409
self ,
0 commit comments