Skip to content

Commit ec4e486

Browse files
authored
Refactor to remove QueryRequest entity (#799)
1 parent 06b0a4c commit ec4e486

File tree

9 files changed

+166
-212
lines changed

9 files changed

+166
-212
lines changed

README.md

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -321,13 +321,11 @@ answer_response = ask(
321321
`ask` is just a convenience wrapper around the real entrypoint, `agent_query`, which you can call directly if you'd like to run concurrent asynchronous workloads:
322322

323323
```python
324-
from paperqa import Settings, agent_query, QueryRequest
324+
from paperqa import Settings, agent_query
325325

326326
answer_response = await agent_query(
327-
QueryRequest(
328-
query="What manufacturing challenges are unique to bispecific antibodies?",
329-
settings=Settings(temperature=0.5, paper_directory="my_papers"),
330-
)
327+
query="What manufacturing challenges are unique to bispecific antibodies?",
328+
settings=Settings(temperature=0.5, paper_directory="my_papers"),
331329
)
332330
```
333331

@@ -682,7 +680,6 @@ import os
682680

683681
from paperqa import Settings
684682
from paperqa.agents.main import agent_query
685-
from paperqa.agents.models import QueryRequest
686683
from paperqa.agents.search import get_directory_index
687684

688685

@@ -696,15 +693,12 @@ async def amain(folder_of_papers: str | os.PathLike) -> None:
696693

697694
# 2. Use the settings as many times as you want with agent_query
698695
answer_response_1 = await agent_query(
699-
query=QueryRequest(
700-
query="What is the best way to make a vaccine?", settings=settings
701-
)
696+
query="What is the best way to make a vaccine?",
697+
settings=settings,
702698
)
703699
answer_response_2 = await agent_query(
704-
query=QueryRequest(
705-
query="What manufacturing challenges are unique to bispecific antibodies?",
706-
settings=settings,
707-
)
700+
query="What manufacturing challenges are unique to bispecific antibodies?",
701+
settings=settings,
708702
)
709703
```
710704

@@ -726,15 +720,13 @@ from ldp.agent import SimpleAgent
726720
from ldp.alg.callbacks import MeanMetricsCallback
727721
from ldp.alg.runners import Evaluator, EvaluatorConfig
728722

729-
from paperqa import QueryRequest, Settings
723+
from paperqa import Settings
730724
from paperqa.agents.task import TASK_DATASET_NAME
731725

732726

733727
async def evaluate(folder_of_litqa_v2_papers: str | os.PathLike) -> None:
734-
base_query = QueryRequest(
735-
settings=Settings(paper_directory=folder_of_litqa_v2_papers)
736-
)
737-
dataset = TaskDataset.from_name(TASK_DATASET_NAME, base_query=base_query)
728+
settings = Settings(paper_directory=folder_of_litqa_v2_papers)
729+
dataset = TaskDataset.from_name(TASK_DATASET_NAME, settings=settings)
738730
metrics_callback = MeanMetricsCallback(eval_dataset=dataset)
739731

740732
evaluator = Evaluator(

paperqa/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from paperqa.agents import ask
1616
from paperqa.agents.main import agent_query
17-
from paperqa.agents.models import QueryRequest
1817
from paperqa.docs import Docs, PQASession, print_callback
1918
from paperqa.llms import (
2019
NumpyVectorStore,
@@ -46,7 +45,6 @@
4645
"NumpyVectorStore",
4746
"PQASession",
4847
"QdrantVectorStore",
49-
"QueryRequest",
5048
"SentenceTransformerEmbeddingModel",
5149
"Settings",
5250
"SparseEmbeddingModel",

paperqa/agents/__init__.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from paperqa.version import __version__
1616

1717
from .main import agent_query, index_search
18-
from .models import AnswerResponse, QueryRequest
18+
from .models import AnswerResponse
1919
from .search import SearchIndex, get_directory_index
2020

2121
logger = logging.getLogger(__name__)
@@ -102,10 +102,7 @@ def ask(query: str | MultipleChoiceQuestion, settings: Settings) -> AnswerRespon
102102
"""Query PaperQA via an agent."""
103103
configure_cli_logging(settings)
104104
return get_loop().run_until_complete(
105-
agent_query(
106-
QueryRequest(query=query, settings=settings),
107-
agent_type=settings.agent.agent_type,
108-
)
105+
agent_query(query, settings, agent_type=settings.agent.agent_type)
109106
)
110107

111108

paperqa/agents/env.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
from copy import deepcopy
33
from typing import Any, ClassVar, Self, cast
4+
from uuid import UUID
45

56
from aviary.core import (
67
Environment,
@@ -23,7 +24,6 @@
2324
from paperqa.types import PQASession
2425
from paperqa.utils import get_year
2526

26-
from .models import QueryRequest
2727
from .tools import (
2828
AVAILABLE_TOOL_NAME_TO_CLASS,
2929
DEFAULT_TOOL_NAMES,
@@ -207,25 +207,27 @@ class PaperQAEnvironment(Environment[EnvironmentState]):
207207

208208
def __init__(
209209
self,
210-
query: QueryRequest,
210+
query: str | MultipleChoiceQuestion,
211+
settings: Settings,
211212
docs: Docs,
212213
llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
213214
summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
214215
embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS,
216+
session_id: UUID | None = None,
215217
**env_kwargs,
216218
):
217219
super().__init__(**env_kwargs)
218-
# Hold onto QueryRequest to create fresh tools and answer during each reset
219220
self._query = query
220-
# Hold onto Docs to clear and reuse in state during each reset
221+
self._settings = settings
221222
self._docs = docs
222223
self._llm_model = llm_model
223224
self._summary_llm_model = summary_llm_model
224225
self._embedding_model = embedding_model
226+
self._session_id = session_id
225227

226228
def make_tools(self) -> list[Tool]:
227229
return settings_to_tools(
228-
settings=self._query.settings,
230+
settings=self._settings,
229231
llm_model=self._llm_model,
230232
summary_llm_model=self._summary_llm_model,
231233
embedding_model=self._embedding_model,
@@ -235,17 +237,23 @@ def make_initial_state(self) -> EnvironmentState:
235237
status_fn = None
236238

237239
if ClinicalTrialsSearch.TOOL_FN_NAME in (
238-
self._query.settings.agent.tool_names or DEFAULT_TOOL_NAMES
240+
self._settings.agent.tool_names or DEFAULT_TOOL_NAMES
239241
):
240242
status_fn = clinical_trial_status
241243

242-
query: str | MultipleChoiceQuestion = self._query.query
244+
session_kwargs: dict[str, Any] = {}
245+
if self._session_id:
246+
session_kwargs["id"] = self._session_id
243247
return EnvironmentState(
244248
docs=self._docs,
245249
session=PQASession(
246-
question=query if isinstance(query, str) else query.question_prompt,
247-
config_md5=self._query.settings.md5,
248-
id=self._query.id,
250+
question=(
251+
self._query
252+
if isinstance(self._query, str)
253+
else self._query.question_prompt
254+
),
255+
config_md5=self._settings.md5,
256+
**session_kwargs,
249257
),
250258
status_fn=status_fn,
251259
)
@@ -259,7 +267,7 @@ async def reset(self) -> tuple[list[Message], list[Tool]]:
259267
return (
260268
[
261269
Message(
262-
content=self._query.settings.agent.agent_prompt.format(
270+
content=self._settings.agent.agent_prompt.format(
263271
question=self.state.session.question,
264272
status=self.state.status,
265273
complete_tool_name=Complete.TOOL_FN_NAME,
@@ -273,15 +281,15 @@ def export_frame(self) -> Frame:
273281
return Frame(state=self.state, info={"query": self._query})
274282

275283
def _has_excess_answer_failures(self) -> bool:
276-
if self._query.settings.answer.max_answer_attempts is None:
284+
if self._settings.answer.max_answer_attempts is None:
277285
return False
278286
return (
279287
sum(
280288
tn == GenerateAnswer.gen_answer.__name__
281289
for s in self.state.session.tool_history
282290
for tn in s
283291
)
284-
> self._query.settings.answer.max_answer_attempts
292+
> self._settings.answer.max_answer_attempts
285293
)
286294

287295
USE_POST_PROCESSED_REWARD: ClassVar[float] = 0.0
@@ -331,7 +339,8 @@ def __deepcopy__(self, memo) -> Self:
331339
)
332340
}
333341
copy_self = type(self)(
334-
query=deepcopy(self._query, memo), # deepcopy for _docs_name
342+
query=self._query, # No need to copy since it is only read
343+
settings=deepcopy(self._settings, memo), # Deepcopy just to be safe
335344
docs=copy_state.docs,
336345
**env_model_kwargs,
337346
)

0 commit comments

Comments
 (0)