
Commit a200ff2

prasmussen15 and claude authored
feat: simplify extraction pipeline and add batch entity summarization (#1224)
* feat(llm): add token usage tracking for LLM calls

  Add TokenUsageTracker class to track input/output tokens by prompt type during LLM calls. This helps analyze token costs across different operations like extract_nodes, extract_edges, resolve_nodes, etc.

  Changes:
  - Add graphiti_core/llm_client/token_tracker.py with TokenUsageTracker
  - Update LLMClient base class to include a token_tracker instance
  - Update OpenAI base client to capture and record token usage
  - Add token_tracker property on the Graphiti class for easy access
  - Update podcast_runner.py to print a token usage summary after ingestion

  Usage:
      client = Graphiti(...)
      # ... run ingestion ...
      client.token_tracker.print_summary(sort_by='prompt_name')

  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* chore: temporarily disable summary early return optimization

  Disable the optimization that skips LLM calls when node summary + edge facts is under 2000 characters. This forces all summaries to be generated via LLM for token usage analysis.

* Revert "chore: temporarily disable summary early return optimization"

  This reverts the summary optimization changes.

* feat: simplify extraction pipeline and add batch entity summarization

  - Remove chunking code for entity-dense episodes (node_operations.py)
    - Delete _extract_nodes_chunked, _extract_from_chunk, _merge_extracted_entities
    - Always use a single LLM call for entity extraction
  - Remove chunking code for edge extraction (edge_operations.py)
    - Remove the MAX_NODES constant and generate_covering_chunks usage
    - Process all nodes in a single LLM call instead of covering subsets
  - Add batch entity summarization (node_operations.py, extract_nodes.py)
    - New SummarizedEntity and SummarizedEntities Pydantic models
    - New extract_summaries_batch prompt for batch processing
    - New _extract_entity_summaries_batch function
    - Nodes with short summaries get edge facts appended directly (no LLM)
    - Only nodes needing LLM summarization are batched together
  - Simplify edge attribute extraction (extract_edges.py, edge_operations.py)
    - Remove episode_content from context (attributes come from the fact only)
    - Keep reference_time for temporal resolution
    - Add existing_attributes to preserve/update existing values
  - Improve edge deduplication prompt (dedupe_edges.py, edge_operations.py)
    - Use continuous indexing across duplicate and invalidation candidates
    - Deduplicate invalidation candidates against duplicate candidates
    - Allow EXISTING FACTS to be both duplicates AND contradicted
    - Consolidate to a single contradicted_facts field
  - Remove obsolete chunking tests (test_entity_extraction.py)

* chore: bump version to 0.27.2pre1

* Add token tracking for Anthropic/Gemini clients and missing tests

  - Implement token tracking in AnthropicClient._generate_response() and generate_response() using result.usage.input_tokens/output_tokens
  - Implement token tracking in GeminiClient._generate_response() and generate_response() using response.usage_metadata
  - Add comprehensive unit tests for the TokenUsageTracker class
  - Add tests for _extract_entity_summaries_batch covering:
    - No nodes needing summarization
    - Short summaries with edge facts
    - Long summaries requiring LLM
    - Node filter (should_summarize_node)
    - Batching multiple nodes
    - Unknown entity handling
    - Missing episode and summary

* Update test_node_operations.py for the batch summarization API

  - Remove the import of extract_attributes_from_node (the function was removed)
  - Add an import of _extract_entity_summaries_batch
  - Update tests to use the new batch summarization API

* Add MAX_NODES limit for batch entity summarization

  - Add a MAX_NODES = 30 constant
  - Partition nodes needing summarization into flights of MAX_NODES
  - Extract a _process_summary_flight helper for processing each flight
  - Each flight makes a separate LLM call to avoid context overflow

* Change default OpenAI models to gpt-5-mini

  Update both DEFAULT_MODEL and DEFAULT_SMALL_MODEL to use gpt-5-mini.

* Update podcast_runner.py to use default OpenAI models

  Remove the explicit model configuration to use the default gpt-5-mini models from OpenAIClient.

* Revert default model changes to gpt-4.1-mini/nano

  Restore the original default models instead of gpt-5-mini.

* Address PR review comments

  - Fix unreachable code in _handle_structured_response (check response.refusal)
  - Process node summary flights in parallel using semaphore_gather
  - Use case-insensitive name matching for LLM summary responses
  - Handle duplicate node names by applying the summary to all matching nodes
  - Fix an edge case when both edge lists are empty in contradiction processing
  - Fix a potential AttributeError when episode is None in edge attributes
  - Add tests for flight partitioning and case-insensitive name matching

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
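The flight partitioning described in the commit message (partition nodes into flights of MAX_NODES, process each flight with its own LLM call, run flights in parallel under a concurrency bound) can be sketched roughly as follows. This is a hypothetical illustration, not graphiti's actual code: `_process_summary_flight` here is a stand-in that just echoes its input, and the plain `asyncio.Semaphore` approximates graphiti's `semaphore_gather` helper.

```python
import asyncio

MAX_NODES = 30  # max entities per summarization flight, per the commit message


def partition_into_flights(nodes: list, flight_size: int = MAX_NODES) -> list[list]:
    """Split nodes needing summarization into flights of at most flight_size."""
    return [nodes[i : i + flight_size] for i in range(0, len(nodes), flight_size)]


async def process_flights(nodes: list, max_concurrency: int = 5) -> list:
    """Process each flight in parallel, bounded by a semaphore.

    In the real pipeline each flight would make one batched LLM call;
    this placeholder simply returns the flight unchanged.
    """
    semaphore = asyncio.Semaphore(max_concurrency)

    async def _process_summary_flight(flight: list) -> list:
        async with semaphore:
            return flight  # stand-in for a single batched LLM call

    flights = partition_into_flights(nodes)
    results = await asyncio.gather(*(_process_summary_flight(f) for f in flights))
    # Flatten the per-flight results back into one list
    return [node for flight in results for node in flight]
```

With 65 nodes this yields flights of 30, 30, and 5, so at most three LLM calls instead of one oversized prompt.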
1 parent fe19482 commit a200ff2

File tree

17 files changed: +1108 additions, -605 deletions


examples/podcast/podcast_runner.py

Lines changed: 11 additions & 1 deletion
@@ -25,6 +25,7 @@
 from transcript_parser import parse_podcast_messages

 from graphiti_core import Graphiti
+from graphiti_core.llm_client import LLMConfig, OpenAIClient
 from graphiti_core.nodes import EpisodeType
 from graphiti_core.utils.bulk_utils import RawEpisode
 from graphiti_core.utils.maintenance.graph_data_operations import clear_data
@@ -85,7 +86,12 @@ class LocatedIn(BaseModel):

 async def main(use_bulk: bool = False):
     setup_logging()
-    client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
+
+    # Configure LLM client
+    llm_config = LLMConfig(model='gpt-4.1-mini', small_model='gpt-4.1-nano')
+    llm_client = OpenAIClient(config=llm_config)
+
+    client = Graphiti(neo4j_uri, neo4j_user, neo4j_password, llm_client=llm_client)
     await clear_data(client.driver)
     await client.build_indices_and_constraints()
     messages = parse_podcast_messages()
@@ -149,5 +155,9 @@ async def main(use_bulk: bool = False):
         saga='Freakonomics Podcast',
     )

+    # Print token usage summary sorted by prompt type
+    print('\n\nIngestion complete. Token usage by prompt type:')
+    client.token_tracker.print_summary(sort_by='prompt_name')
+

 asyncio.run(main(False))

graphiti_core/graphiti.py

Lines changed: 12 additions & 0 deletions
@@ -260,6 +260,18 @@ def _capture_initialization_telemetry(self):
            # Silently handle telemetry errors
            pass

+    @property
+    def token_tracker(self):
+        """Access the LLM client's token usage tracker.
+
+        Returns the TokenUsageTracker from the LLM client, which can be used to:
+        - Get token usage by prompt type: tracker.get_usage()
+        - Get total token usage: tracker.get_total_usage()
+        - Print a formatted summary: tracker.print_summary()
+        - Reset tracking: tracker.reset()
+        """
+        return self.llm_client.token_tracker
+
     def _get_provider_type(self, client) -> str:
         """Get provider type from client class name."""
         if client is None:
graphiti_core/llm_client/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -18,5 +18,13 @@
 from .config import LLMConfig
 from .errors import RateLimitError
 from .openai_client import OpenAIClient
+from .token_tracker import TokenUsage, TokenUsageTracker

-__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig', 'RateLimitError']
+__all__ = [
+    'LLMClient',
+    'OpenAIClient',
+    'LLMConfig',
+    'RateLimitError',
+    'TokenUsage',
+    'TokenUsageTracker',
+]

graphiti_core/llm_client/anthropic_client.py

Lines changed: 19 additions & 5 deletions
@@ -257,7 +257,7 @@ async def _generate_response(
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
-    ) -> dict[str, typing.Any]:
+    ) -> tuple[dict[str, typing.Any], int, int]:
         """
         Generate a response from the Anthropic LLM using tool-based approach for all requests.

@@ -267,7 +267,7 @@
             max_tokens: Maximum number of tokens to generate.

         Returns:
-            Dictionary containing the structured response from the LLM.
+            Tuple of (response_dict, input_tokens, output_tokens).

         Raises:
             RateLimitError: If the rate limit is exceeded.
@@ -295,19 +295,26 @@
                 tool_choice=tool_choice,
             )

+            # Extract token usage from the response
+            input_tokens = 0
+            output_tokens = 0
+            if hasattr(result, 'usage') and result.usage:
+                input_tokens = getattr(result.usage, 'input_tokens', 0) or 0
+                output_tokens = getattr(result.usage, 'output_tokens', 0) or 0
+
             # Extract the tool output from the response
             for content_item in result.content:
                 if content_item.type == 'tool_use':
                     if isinstance(content_item.input, dict):
                         tool_args: dict[str, typing.Any] = content_item.input
                     else:
                         tool_args = json.loads(str(content_item.input))
-                    return tool_args
+                    return tool_args, input_tokens, output_tokens

             # If we didn't get a proper tool_use response, try to extract from text
             for content_item in result.content:
                 if content_item.type == 'text':
-                    return self._extract_json_from_text(content_item.text)
+                    return self._extract_json_from_text(content_item.text), input_tokens, output_tokens
             else:
                 raise ValueError(
                     f'Could not extract structured data from model response: {result.content}'
@@ -372,12 +379,19 @@ async def generate_response(
         retry_count = 0
         max_retries = 2
         last_error: Exception | None = None
+        total_input_tokens = 0
+        total_output_tokens = 0

         while retry_count <= max_retries:
             try:
-                response = await self._generate_response(
+                response, input_tokens, output_tokens = await self._generate_response(
                     messages, response_model, max_tokens, model_size
                 )
+                total_input_tokens += input_tokens
+                total_output_tokens += output_tokens
+
+                # Record token usage
+                self.token_tracker.record(prompt_name, total_input_tokens, total_output_tokens)

                 # If we have a response_model, attempt to validate the response
                 if response_model is not None:

graphiti_core/llm_client/client.py

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,7 @@
 from ..tracer import NoOpTracer, Tracer
 from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
 from .errors import RateLimitError
+from .token_tracker import TokenUsageTracker

 DEFAULT_TEMPERATURE = 0
 DEFAULT_CACHE_DIR = './llm_cache'
@@ -80,6 +81,7 @@ def __init__(self, config: LLMConfig | None, cache: bool = False):
         self.cache_enabled = cache
         self.cache_dir = None
         self.tracer: Tracer = NoOpTracer()
+        self.token_tracker: TokenUsageTracker = TokenUsageTracker()

         # Only create the cache directory if caching is enabled
         if self.cache_enabled:

graphiti_core/llm_client/gemini_client.py

Lines changed: 21 additions & 6 deletions
@@ -239,7 +239,7 @@ async def _generate_response(
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
-    ) -> dict[str, typing.Any]:
+    ) -> tuple[dict[str, typing.Any], int, int]:
         """
         Generate a response from the Gemini language model.

@@ -250,7 +250,7 @@
             model_size (ModelSize): The size of the model to use (small or medium).

         Returns:
-            dict[str, typing.Any]: The response from the language model.
+            tuple[dict[str, typing.Any], int, int]: The response dict, input tokens, and output tokens.

         Raises:
             RateLimitError: If the API rate limit is exceeded.
@@ -306,6 +306,13 @@
                 config=generation_config,
             )

+            # Extract token usage from the response
+            input_tokens = 0
+            output_tokens = 0
+            if hasattr(response, 'usage_metadata') and response.usage_metadata:
+                input_tokens = getattr(response.usage_metadata, 'prompt_token_count', 0) or 0
+                output_tokens = getattr(response.usage_metadata, 'candidates_token_count', 0) or 0
+
             # Always capture the raw output for debugging
             raw_output = getattr(response, 'text', None)

@@ -322,7 +329,7 @@
                     validated_model = response_model.model_validate(json.loads(raw_output))

                     # Return as a dictionary for API consistency
-                    return validated_model.model_dump()
+                    return validated_model.model_dump(), input_tokens, output_tokens
                 except Exception as e:
                     if raw_output:
                         logger.error(
@@ -333,11 +340,11 @@
                         salvaged = self.salvage_json(raw_output)
                         if salvaged is not None:
                             logger.warning('Salvaged partial JSON from truncated/malformed output.')
-                            return salvaged
+                            return salvaged, input_tokens, output_tokens
                     raise Exception(f'Failed to parse structured response: {e}') from e

             # Otherwise, return the response text as a dictionary
-            return {'content': raw_output}
+            return {'content': raw_output}, input_tokens, output_tokens

         except Exception as e:
             # Check if it's a rate limit error based on Gemini API error codes
@@ -394,15 +401,23 @@ async def generate_response(
         retry_count = 0
         last_error = None
         last_output = None
+        total_input_tokens = 0
+        total_output_tokens = 0

         while retry_count < self.MAX_RETRIES:
             try:
-                response = await self._generate_response(
+                response, input_tokens, output_tokens = await self._generate_response(
                     messages=messages,
                     response_model=response_model,
                     max_tokens=max_tokens,
                     model_size=model_size,
                 )
+                total_input_tokens += input_tokens
+                total_output_tokens += output_tokens
+
+                # Record token usage
+                self.token_tracker.record(prompt_name, total_input_tokens, total_output_tokens)
+
                 last_output = (
                     response.get('content')
                     if isinstance(response, dict) and 'content' in response

graphiti_core/llm_client/openai_base_client.py

Lines changed: 49 additions & 12 deletions
@@ -113,30 +113,59 @@ def _get_model_for_size(self, model_size: ModelSize) -> str:
         else:
             return self.model or DEFAULT_MODEL

-    def _handle_structured_response(self, response: Any) -> dict[str, Any]:
-        """Handle structured response parsing and validation."""
+    def _handle_structured_response(
+        self, response: Any
+    ) -> tuple[dict[str, Any], int, int]:
+        """Handle structured response parsing and validation.
+
+        Returns:
+            tuple: (parsed_response, input_tokens, output_tokens)
+        """
         response_object = response.output_text

+        # Extract token usage
+        input_tokens = 0
+        output_tokens = 0
+        if hasattr(response, 'usage') and response.usage:
+            input_tokens = getattr(response.usage, 'input_tokens', 0) or 0
+            output_tokens = getattr(response.usage, 'output_tokens', 0) or 0
+
         if response_object:
-            return json.loads(response_object)
-        elif response_object.refusal:
-            raise RefusalError(response_object.refusal)
+            return json.loads(response_object), input_tokens, output_tokens
+        elif hasattr(response, 'refusal') and response.refusal:
+            raise RefusalError(response.refusal)
         else:
-            raise Exception(f'Invalid response from LLM: {response_object.model_dump()}')
+            raise Exception(f'Invalid response from LLM: {response}')
+
+    def _handle_json_response(self, response: Any) -> tuple[dict[str, Any], int, int]:
+        """Handle JSON response parsing.

-    def _handle_json_response(self, response: Any) -> dict[str, Any]:
-        """Handle JSON response parsing."""
+        Returns:
+            tuple: (parsed_response, input_tokens, output_tokens)
+        """
         result = response.choices[0].message.content or '{}'
-        return json.loads(result)
+
+        # Extract token usage
+        input_tokens = 0
+        output_tokens = 0
+        if hasattr(response, 'usage') and response.usage:
+            input_tokens = getattr(response.usage, 'prompt_tokens', 0) or 0
+            output_tokens = getattr(response.usage, 'completion_tokens', 0) or 0
+
+        return json.loads(result), input_tokens, output_tokens

     async def _generate_response(
         self,
         messages: list[Message],
         response_model: type[BaseModel] | None = None,
         max_tokens: int = DEFAULT_MAX_TOKENS,
         model_size: ModelSize = ModelSize.medium,
-    ) -> dict[str, Any]:
-        """Generate a response using the appropriate client implementation."""
+    ) -> tuple[dict[str, Any], int, int]:
+        """Generate a response using the appropriate client implementation.
+
+        Returns:
+            tuple: (response_dict, input_tokens, output_tokens)
+        """
         openai_messages = self._convert_messages_to_openai_format(messages)
         model = self._get_model_for_size(model_size)

@@ -210,12 +239,20 @@ async def generate_response(

         retry_count = 0
         last_error = None
+        total_input_tokens = 0
+        total_output_tokens = 0

         while retry_count <= self.MAX_RETRIES:
             try:
-                response = await self._generate_response(
+                response, input_tokens, output_tokens = await self._generate_response(
                     messages, response_model, max_tokens, model_size
                 )
+                total_input_tokens += input_tokens
+                total_output_tokens += output_tokens
+
+                # Record token usage
+                self.token_tracker.record(prompt_name, total_input_tokens, total_output_tokens)
+
                 return response
             except (RateLimitError, RefusalError):
                 # These errors should not trigger retries
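Among the review fixes listed in the commit message is case-insensitive name matching when applying batch LLM summary responses back to nodes, with duplicate node names all receiving the matching summary. A rough sketch of that matching step follows; the dict shapes for nodes and summary responses are assumptions for illustration, not graphiti's actual Pydantic models:

```python
def apply_summaries(nodes: list[dict], summaries: list[dict]) -> None:
    """Apply each returned summary to every node whose name matches, ignoring case.

    nodes:     [{'name': ..., 'summary': ...}, ...]  (assumed shape)
    summaries: [{'name': ..., 'summary': ...}, ...]  (assumed LLM response shape)
    """
    # Index nodes by lowercased name; duplicates share one bucket
    by_name: dict[str, list[dict]] = {}
    for node in nodes:
        by_name.setdefault(node['name'].lower(), []).append(node)

    for item in summaries:
        # All nodes with a matching name receive the same summary;
        # unknown names from the LLM are silently skipped
        for node in by_name.get(item['name'].lower(), []):
            node['summary'] = item['summary']
```

Lowercasing on both sides makes the match robust to the LLM echoing 'alice' for a node named 'Alice', and the bucket list handles the duplicate-name case from the review comments.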
