Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion examples/podcast/podcast_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from transcript_parser import parse_podcast_messages

from graphiti_core import Graphiti
from graphiti_core.llm_client import LLMConfig, OpenAIClient
from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.bulk_utils import RawEpisode
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
Expand Down Expand Up @@ -85,7 +86,12 @@ class LocatedIn(BaseModel):

async def main(use_bulk: bool = False):
setup_logging()
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)

# Configure LLM client
llm_config = LLMConfig(model='gpt-4.1-mini', small_model='gpt-4.1-nano')
llm_client = OpenAIClient(config=llm_config)

client = Graphiti(neo4j_uri, neo4j_user, neo4j_password, llm_client=llm_client)
await clear_data(client.driver)
await client.build_indices_and_constraints()
messages = parse_podcast_messages()
Expand Down Expand Up @@ -149,5 +155,9 @@ async def main(use_bulk: bool = False):
saga='Freakonomics Podcast',
)

# Print token usage summary sorted by prompt type
print('\n\nIngestion complete. Token usage by prompt type:')
client.token_tracker.print_summary(sort_by='prompt_name')


asyncio.run(main(False))
12 changes: 12 additions & 0 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,18 @@ def _capture_initialization_telemetry(self):
# Silently handle telemetry errors
pass

@property
def token_tracker(self):
    """Expose the underlying LLM client's TokenUsageTracker.

    The returned tracker supports:
    - per-prompt usage via tracker.get_usage()
    - aggregate usage via tracker.get_total_usage()
    - a formatted report via tracker.print_summary()
    - clearing counters via tracker.reset()
    """
    tracker = self.llm_client.token_tracker
    return tracker

def _get_provider_type(self, client) -> str:
"""Get provider type from client class name."""
if client is None:
Expand Down
10 changes: 9 additions & 1 deletion graphiti_core/llm_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,13 @@
from .config import LLMConfig
from .errors import RateLimitError
from .openai_client import OpenAIClient
from .token_tracker import TokenUsage, TokenUsageTracker

__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig', 'RateLimitError']
# Public names re-exported at the llm_client package level, including the
# token-usage tracking helpers (TokenUsage, TokenUsageTracker).
__all__ = [
'LLMClient',
'OpenAIClient',
'LLMConfig',
'RateLimitError',
'TokenUsage',
'TokenUsageTracker',
]
24 changes: 19 additions & 5 deletions graphiti_core/llm_client/anthropic_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ async def _generate_response(
response_model: type[BaseModel] | None = None,
max_tokens: int | None = None,
model_size: ModelSize = ModelSize.medium,
) -> dict[str, typing.Any]:
) -> tuple[dict[str, typing.Any], int, int]:
"""
Generate a response from the Anthropic LLM using tool-based approach for all requests.

Expand All @@ -267,7 +267,7 @@ async def _generate_response(
max_tokens: Maximum number of tokens to generate.

Returns:
Dictionary containing the structured response from the LLM.
Tuple of (response_dict, input_tokens, output_tokens).

Raises:
RateLimitError: If the rate limit is exceeded.
Expand Down Expand Up @@ -295,19 +295,26 @@ async def _generate_response(
tool_choice=tool_choice,
)

# Extract token usage from the response
input_tokens = 0
output_tokens = 0
if hasattr(result, 'usage') and result.usage:
input_tokens = getattr(result.usage, 'input_tokens', 0) or 0
output_tokens = getattr(result.usage, 'output_tokens', 0) or 0

# Extract the tool output from the response
for content_item in result.content:
if content_item.type == 'tool_use':
if isinstance(content_item.input, dict):
tool_args: dict[str, typing.Any] = content_item.input
else:
tool_args = json.loads(str(content_item.input))
return tool_args
return tool_args, input_tokens, output_tokens

# If we didn't get a proper tool_use response, try to extract from text
for content_item in result.content:
if content_item.type == 'text':
return self._extract_json_from_text(content_item.text)
return self._extract_json_from_text(content_item.text), input_tokens, output_tokens
else:
raise ValueError(
f'Could not extract structured data from model response: {result.content}'
Expand Down Expand Up @@ -372,12 +379,19 @@ async def generate_response(
retry_count = 0
max_retries = 2
last_error: Exception | None = None
total_input_tokens = 0
total_output_tokens = 0

while retry_count <= max_retries:
try:
response = await self._generate_response(
response, input_tokens, output_tokens = await self._generate_response(
messages, response_model, max_tokens, model_size
)
total_input_tokens += input_tokens
total_output_tokens += output_tokens

# Record token usage
self.token_tracker.record(prompt_name, total_input_tokens, total_output_tokens)

# If we have a response_model, attempt to validate the response
if response_model is not None:
Expand Down
2 changes: 2 additions & 0 deletions graphiti_core/llm_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ..tracer import NoOpTracer, Tracer
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
from .errors import RateLimitError
from .token_tracker import TokenUsageTracker

DEFAULT_TEMPERATURE = 0
DEFAULT_CACHE_DIR = './llm_cache'
Expand Down Expand Up @@ -80,6 +81,7 @@ def __init__(self, config: LLMConfig | None, cache: bool = False):
self.cache_enabled = cache
self.cache_dir = None
self.tracer: Tracer = NoOpTracer()
self.token_tracker: TokenUsageTracker = TokenUsageTracker()

# Only create the cache directory if caching is enabled
if self.cache_enabled:
Expand Down
27 changes: 21 additions & 6 deletions graphiti_core/llm_client/gemini_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ async def _generate_response(
response_model: type[BaseModel] | None = None,
max_tokens: int | None = None,
model_size: ModelSize = ModelSize.medium,
) -> dict[str, typing.Any]:
) -> tuple[dict[str, typing.Any], int, int]:
"""
Generate a response from the Gemini language model.

Expand All @@ -250,7 +250,7 @@ async def _generate_response(
model_size (ModelSize): The size of the model to use (small or medium).

Returns:
dict[str, typing.Any]: The response from the language model.
tuple[dict[str, typing.Any], int, int]: The response dict, input tokens, and output tokens.

Raises:
RateLimitError: If the API rate limit is exceeded.
Expand Down Expand Up @@ -306,6 +306,13 @@ async def _generate_response(
config=generation_config,
)

# Extract token usage from the response
input_tokens = 0
output_tokens = 0
if hasattr(response, 'usage_metadata') and response.usage_metadata:
input_tokens = getattr(response.usage_metadata, 'prompt_token_count', 0) or 0
output_tokens = getattr(response.usage_metadata, 'candidates_token_count', 0) or 0

# Always capture the raw output for debugging
raw_output = getattr(response, 'text', None)

Expand All @@ -322,7 +329,7 @@ async def _generate_response(
validated_model = response_model.model_validate(json.loads(raw_output))

# Return as a dictionary for API consistency
return validated_model.model_dump()
return validated_model.model_dump(), input_tokens, output_tokens
except Exception as e:
if raw_output:
logger.error(
Expand All @@ -333,11 +340,11 @@ async def _generate_response(
salvaged = self.salvage_json(raw_output)
if salvaged is not None:
logger.warning('Salvaged partial JSON from truncated/malformed output.')
return salvaged
return salvaged, input_tokens, output_tokens
raise Exception(f'Failed to parse structured response: {e}') from e

# Otherwise, return the response text as a dictionary
return {'content': raw_output}
return {'content': raw_output}, input_tokens, output_tokens

except Exception as e:
# Check if it's a rate limit error based on Gemini API error codes
Expand Down Expand Up @@ -394,15 +401,23 @@ async def generate_response(
retry_count = 0
last_error = None
last_output = None
total_input_tokens = 0
total_output_tokens = 0

while retry_count < self.MAX_RETRIES:
try:
response = await self._generate_response(
response, input_tokens, output_tokens = await self._generate_response(
messages=messages,
response_model=response_model,
max_tokens=max_tokens,
model_size=model_size,
)
total_input_tokens += input_tokens
total_output_tokens += output_tokens

# Record token usage
self.token_tracker.record(prompt_name, total_input_tokens, total_output_tokens)

last_output = (
response.get('content')
if isinstance(response, dict) and 'content' in response
Expand Down
61 changes: 49 additions & 12 deletions graphiti_core/llm_client/openai_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,30 +113,59 @@ def _get_model_for_size(self, model_size: ModelSize) -> str:
else:
return self.model or DEFAULT_MODEL

def _handle_structured_response(self, response: Any) -> dict[str, Any]:
"""Handle structured response parsing and validation."""
def _handle_structured_response(
self, response: Any
) -> tuple[dict[str, Any], int, int]:
"""Handle structured response parsing and validation.

Returns:
tuple: (parsed_response, input_tokens, output_tokens)
"""
response_object = response.output_text

# Extract token usage
input_tokens = 0
output_tokens = 0
if hasattr(response, 'usage') and response.usage:
input_tokens = getattr(response.usage, 'input_tokens', 0) or 0
output_tokens = getattr(response.usage, 'output_tokens', 0) or 0

if response_object:
return json.loads(response_object)
elif response_object.refusal:
raise RefusalError(response_object.refusal)
return json.loads(response_object), input_tokens, output_tokens
elif hasattr(response, 'refusal') and response.refusal:
raise RefusalError(response.refusal)
else:
raise Exception(f'Invalid response from LLM: {response_object.model_dump()}')
raise Exception(f'Invalid response from LLM: {response}')

def _handle_json_response(self, response: Any) -> tuple[dict[str, Any], int, int]:
"""Handle JSON response parsing.

def _handle_json_response(self, response: Any) -> dict[str, Any]:
"""Handle JSON response parsing."""
Returns:
tuple: (parsed_response, input_tokens, output_tokens)
"""
result = response.choices[0].message.content or '{}'
return json.loads(result)

# Extract token usage
input_tokens = 0
output_tokens = 0
if hasattr(response, 'usage') and response.usage:
input_tokens = getattr(response.usage, 'prompt_tokens', 0) or 0
output_tokens = getattr(response.usage, 'completion_tokens', 0) or 0

return json.loads(result), input_tokens, output_tokens

async def _generate_response(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
model_size: ModelSize = ModelSize.medium,
) -> dict[str, Any]:
"""Generate a response using the appropriate client implementation."""
) -> tuple[dict[str, Any], int, int]:
"""Generate a response using the appropriate client implementation.

Returns:
tuple: (response_dict, input_tokens, output_tokens)
"""
openai_messages = self._convert_messages_to_openai_format(messages)
model = self._get_model_for_size(model_size)

Expand Down Expand Up @@ -210,12 +239,20 @@ async def generate_response(

retry_count = 0
last_error = None
total_input_tokens = 0
total_output_tokens = 0

while retry_count <= self.MAX_RETRIES:
try:
response = await self._generate_response(
response, input_tokens, output_tokens = await self._generate_response(
messages, response_model, max_tokens, model_size
)
total_input_tokens += input_tokens
total_output_tokens += output_tokens

# Record token usage
self.token_tracker.record(prompt_name, total_input_tokens, total_output_tokens)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Token usage is recorded even when there's an exception during retry attempts. The total_input_tokens and total_output_tokens are accumulated across retries, but if a retry fails after a successful initial call, the tracker will record tokens from both the successful and failed attempts, potentially double-counting.

Consider moving the token_tracker.record() call outside the retry loop, or only record on the first successful response.


return response
except (RateLimitError, RefusalError):
# These errors should not trigger retries
Expand Down
Loading
Loading