32 changes: 17 additions & 15 deletions extract_thinker/concatenation_handler.py
@@ -27,6 +27,7 @@ def _is_valid_json_continuation(self, response: str) -> bool:
 
         return has_json_markers
 
+
     def handle(self, content: Any, response_model: type[BaseModel], vision: bool = False, extra_content: Optional[str] = None) -> Any:
         self.json_parts = []
         messages = self._build_messages(content, vision, response_model)
@@ -38,27 +39,28 @@ def handle(self, content: Any, response_model: type[BaseModel], vision: bool = False, extra_content: Optional[str] = None) -> Any:
         max_retries = 3
         while True:
             try:
-                response = self.llm.raw_completion(messages)
-
-                # Validate if it's a proper JSON continuation
-                if not self._is_valid_json_continuation(response):
-                    retry_count += 1
-                    if retry_count >= max_retries:
-                        raise ValueError("Maximum retries reached with invalid JSON continuations")
-                    continue
-
+                response_obj = self.llm.raw_completion_complete(messages)
+                response = response_obj.message.content
+                finish_reason = response_obj.finish_reason
+
                 self.json_parts.append(response)
 
-                # Try to process and validate the JSON
-                result = self._process_json_parts(response_model)
-                return result
+                if self._is_finish(response_obj):
+                    result = self._process_json_parts(response_model)
+                    return result
+
+                retry_count += 1
+                if retry_count >= max_retries:
+                    raise ValueError("Maximum retries reached with incomplete response")
+                messages = self._build_continuation_messages(messages, response)
 
             except ValueError as e:
                 if retry_count >= max_retries:
                     raise ValueError(f"Maximum retries reached: {str(e)}")
                 retry_count += 1
                 messages = self._build_continuation_messages(messages, response)
 
 
     def _process_json_parts(self, response_model: type[BaseModel]) -> Any:
         """Process collected JSON parts into a complete response."""
         if not self.json_parts:
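Note that the rewritten handle loop calls self._is_finish(response_obj), a helper that does not appear in this diff. For reference, a minimal sketch of what such a check could look like, assuming it simply inspects the finish_reason on the litellm choice object returned by raw_completion_complete (an illustrative guess, not the PR's actual implementation):

    def _is_finish(self, response_obj) -> bool:
        """Illustrative guess: treat the response as complete when generation stopped naturally."""
        # litellm exposes OpenAI-style finish reasons on each choice: "stop" means the model
        # ended on its own, while "length" means the output hit the token limit and the
        # handler should request a continuation.
        return getattr(response_obj, "finish_reason", None) == "stop"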
46 changes: 46 additions & 0 deletions extract_thinker/llm.py
@@ -322,6 +322,52 @@ def raw_completion(self, messages: List[Dict[str, str]]) -> str:
             raw_response = litellm.completion(**params)
 
         return raw_response.choices[0].message.content
+
+    def raw_completion_complete(self, messages: List[Dict[str, str]]) -> str:
+        """Make raw completion request without response model."""
+        if self.backend == LLMEngine.PYDANTIC_AI:
+            # Combine messages into a single prompt
+            combined_prompt = " ".join([m["content"] for m in messages])
+            try:
+                result = asyncio.run(
+                    self.agent.run(
+                        combined_prompt,
+                        result_type=str
+                    )
+                )
+                return result.data
+            except Exception as e:
+                raise ValueError(f"Failed to extract from source: {str(e)}")
+
+        max_tokens = self.DEFAULT_OUTPUT_TOKENS
+        if self.token_limit is not None:
+            max_tokens = self.token_limit
+        elif self.is_thinking:
+            max_tokens = self.thinking_token_limit
+
+        params = {
+            "model": self.model,
+            "messages": messages,
+            "max_completion_tokens": max_tokens,
+        }
+
+        if self.is_thinking:
+            if litellm.supports_reasoning(self.model):
+                # Add thinking parameter for supported models
+                thinking_param = {
+                    "type": "enabled",
+                    "budget_tokens": self.thinking_budget
+                }
+                params["thinking"] = thinking_param
+            else:
+                print(f"Warning: Model {self.model} doesn't support thinking parameter, proceeding without it.")
+
+        if self.router:
+            raw_response = self.router.completion(**params)
+        else:
+            raw_response = litellm.completion(**params)
+
+        return raw_response.choices[0]
 
     def set_timeout(self, timeout_ms: int) -> None:
         """Set the timeout value for LLM requests in milliseconds."""
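To see how the two changes fit together, here is a rough usage sketch of raw_completion_complete from a caller's point of view, assuming the LLM wrapper class defined in extract_thinker/llm.py and litellm's OpenAI-style choice object (the model name and prompt are illustrative, not code from this PR):

    from extract_thinker.llm import LLM

    llm = LLM("gpt-4o")  # illustrative model name

    messages = [{"role": "user", "content": "Return the invoice fields as JSON."}]
    choice = llm.raw_completion_complete(messages)

    text = choice.message.content                  # raw (possibly truncated) completion text
    truncated = choice.finish_reason == "length"   # "stop" would mean the model finished on its own

    if truncated:
        # ConcatenationHandler appends `text` to its collected JSON parts and sends a
        # continuation request instead of parsing the partial JSON immediately.
        pass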