1
1
import logging
2
2
from copy import deepcopy
3
3
from typing import Any , ClassVar , Self , cast
4
+ from uuid import UUID
4
5
5
6
from aviary .core import (
6
7
Environment ,
23
24
from paperqa .types import PQASession
24
25
from paperqa .utils import get_year
25
26
26
- from .models import QueryRequest
27
27
from .tools import (
28
28
AVAILABLE_TOOL_NAME_TO_CLASS ,
29
29
DEFAULT_TOOL_NAMES ,
@@ -207,25 +207,27 @@ class PaperQAEnvironment(Environment[EnvironmentState]):
207
207
208
208
def __init__ (
209
209
self ,
210
- query : QueryRequest ,
210
+ query : str | MultipleChoiceQuestion ,
211
+ settings : Settings ,
211
212
docs : Docs ,
212
213
llm_model : LiteLLMModel | None = POPULATE_FROM_SETTINGS ,
213
214
summary_llm_model : LiteLLMModel | None = POPULATE_FROM_SETTINGS ,
214
215
embedding_model : EmbeddingModel | None = POPULATE_FROM_SETTINGS ,
216
+ session_id : UUID | None = None ,
215
217
** env_kwargs ,
216
218
):
217
219
super ().__init__ (** env_kwargs )
218
- # Hold onto QueryRequest to create fresh tools and answer during each reset
219
220
self ._query = query
220
- # Hold onto Docs to clear and reuse in state during each reset
221
+ self . _settings = settings
221
222
self ._docs = docs
222
223
self ._llm_model = llm_model
223
224
self ._summary_llm_model = summary_llm_model
224
225
self ._embedding_model = embedding_model
226
+ self ._session_id = session_id
225
227
226
228
def make_tools (self ) -> list [Tool ]:
227
229
return settings_to_tools (
228
- settings = self ._query . settings ,
230
+ settings = self ._settings ,
229
231
llm_model = self ._llm_model ,
230
232
summary_llm_model = self ._summary_llm_model ,
231
233
embedding_model = self ._embedding_model ,
@@ -235,17 +237,23 @@ def make_initial_state(self) -> EnvironmentState:
235
237
status_fn = None
236
238
237
239
if ClinicalTrialsSearch .TOOL_FN_NAME in (
238
- self ._query . settings .agent .tool_names or DEFAULT_TOOL_NAMES
240
+ self ._settings .agent .tool_names or DEFAULT_TOOL_NAMES
239
241
):
240
242
status_fn = clinical_trial_status
241
243
242
- query : str | MultipleChoiceQuestion = self ._query .query
244
+ session_kwargs : dict [str , Any ] = {}
245
+ if self ._session_id :
246
+ session_kwargs ["id" ] = self ._session_id
243
247
return EnvironmentState (
244
248
docs = self ._docs ,
245
249
session = PQASession (
246
- question = query if isinstance (query , str ) else query .question_prompt ,
247
- config_md5 = self ._query .settings .md5 ,
248
- id = self ._query .id ,
250
+ question = (
251
+ self ._query
252
+ if isinstance (self ._query , str )
253
+ else self ._query .question_prompt
254
+ ),
255
+ config_md5 = self ._settings .md5 ,
256
+ ** session_kwargs ,
249
257
),
250
258
status_fn = status_fn ,
251
259
)
@@ -259,7 +267,7 @@ async def reset(self) -> tuple[list[Message], list[Tool]]:
259
267
return (
260
268
[
261
269
Message (
262
- content = self ._query . settings .agent .agent_prompt .format (
270
+ content = self ._settings .agent .agent_prompt .format (
263
271
question = self .state .session .question ,
264
272
status = self .state .status ,
265
273
complete_tool_name = Complete .TOOL_FN_NAME ,
@@ -273,15 +281,15 @@ def export_frame(self) -> Frame:
273
281
return Frame (state = self .state , info = {"query" : self ._query })
274
282
275
283
def _has_excess_answer_failures (self ) -> bool :
276
- if self ._query . settings .answer .max_answer_attempts is None :
284
+ if self ._settings .answer .max_answer_attempts is None :
277
285
return False
278
286
return (
279
287
sum (
280
288
tn == GenerateAnswer .gen_answer .__name__
281
289
for s in self .state .session .tool_history
282
290
for tn in s
283
291
)
284
- > self ._query . settings .answer .max_answer_attempts
292
+ > self ._settings .answer .max_answer_attempts
285
293
)
286
294
287
295
USE_POST_PROCESSED_REWARD : ClassVar [float ] = 0.0
@@ -331,7 +339,8 @@ def __deepcopy__(self, memo) -> Self:
331
339
)
332
340
}
333
341
copy_self = type (self )(
334
- query = deepcopy (self ._query , memo ), # deepcopy for _docs_name
342
+ query = self ._query , # No need to copy since we read only
343
+ settings = deepcopy (self ._settings , memo ), # Deepcopy just to be safe
335
344
docs = copy_state .docs ,
336
345
** env_model_kwargs ,
337
346
)
0 commit comments