Skip to content

Commit 0afa795

Browse files
authored
fix: correct state race condition (#270)
* fix: only persist data once Signed-off-by: Samantha Coyle <[email protected]> * style: rm debug log Signed-off-by: Samantha Coyle <[email protected]> * style: appease linter Signed-off-by: Samantha Coyle <[email protected]> * fix: also prevent duplicate tool msgs upon constructing chat history Signed-off-by: Samantha Coyle <[email protected]> * fix: actually address race condition on parallel tool result state saves and use mem as signle source of truth if available Signed-off-by: Samantha Coyle <[email protected]> * fix: always load state, add mem msg atomic and always save msgs to mem Signed-off-by: Samantha Coyle <[email protected]> * style: lint fix Signed-off-by: Samantha Coyle <[email protected]> * fix: appease flake* Signed-off-by: Samantha Coyle <[email protected]> * style: last fix for flake8 Signed-off-by: Samantha Coyle <[email protected]> * fix: updates for tests Signed-off-by: Samantha Coyle <[email protected]> * style: make lint happy Signed-off-by: Samantha Coyle <[email protected]> * fix: bring changes to correct branch ugh Signed-off-by: Samantha Coyle <[email protected]> * fix: i hate the linter Signed-off-by: Samantha Coyle <[email protected]> --------- Signed-off-by: Samantha Coyle <[email protected]>
1 parent 8d63690 commit 0afa795

File tree

6 files changed

+166
-86
lines changed

6 files changed

+166
-86
lines changed

dapr_agents/agents/base.py

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ def _get_last_user_message(
512512
# ------------------------------------------------------------------
513513
# State-aware message helpers (use AgentComponents' state model)
514514
# ------------------------------------------------------------------
515-
def _construct_messages_with_instance_history(
515+
def _reconstruct_conversation_history(
516516
self, instance_id: str
517517
) -> List[Dict[str, Any]]:
518518
"""
@@ -546,10 +546,11 @@ def _construct_messages_with_instance_history(
546546
except Exception: # noqa: BLE001
547547
logger.debug("Unable to load persistent memory.", exc_info=True)
548548

549-
history: List[Dict[str, Any]] = []
550-
history.extend(persistent_memory)
551-
history.extend(instance_messages)
552-
return history
549+
# Persistent conversation history in the memory config is the single source of truth for conversation history
550+
if persistent_memory:
551+
return persistent_memory
552+
# Note: this is just ot make tests happy for now and in reality for durable agent this is not used for app resumption of state
553+
return instance_messages
553554

554555
def _sync_system_messages_with_state(
555556
self,
@@ -585,23 +586,22 @@ def _process_user_message(
585586

586587
container = self._get_entry_container()
587588
entry = container.get(instance_id) if container else None
588-
if entry is None or not hasattr(entry, "messages"):
589-
return
590-
591-
# Use configured coercer / message model
592-
message_model = (
593-
self._message_coercer(user_message_copy) # type: ignore[attr-defined]
594-
if getattr(self, "_message_coercer", None)
595-
else self._message_dict_to_message_model(user_message_copy)
596-
)
597-
entry.messages.append(message_model) # type: ignore[attr-defined]
598-
if hasattr(entry, "last_message"):
599-
entry.last_message = message_model # type: ignore[attr-defined]
589+
if entry is not None and hasattr(entry, "messages"):
590+
# Use configured coercer / message model
591+
message_model = (
592+
self._message_coercer(user_message_copy) # type: ignore[attr-defined]
593+
if getattr(self, "_message_coercer", None)
594+
else self._message_dict_to_message_model(user_message_copy)
595+
)
596+
entry.messages.append(message_model) # type: ignore[attr-defined]
597+
if hasattr(entry, "last_message"):
598+
entry.last_message = message_model # type: ignore[attr-defined]
600599

601-
session_id = getattr(getattr(self, "memory", None), "session_id", None)
602-
if session_id is not None and hasattr(entry, "session_id"):
603-
entry.session_id = str(session_id) # type: ignore[attr-defined]
600+
session_id = getattr(getattr(self, "memory", None), "session_id", None)
601+
if session_id is not None and hasattr(entry, "session_id"):
602+
entry.session_id = str(session_id) # type: ignore[attr-defined]
604603

604+
# Always add to memory (required for chat history for agent durability upon restarts)
605605
self.memory.add_message(
606606
UserMessage(content=user_message_copy.get("content", ""))
607607
)
@@ -621,24 +621,25 @@ def _save_assistant_message(
621621

622622
container = self._get_entry_container()
623623
entry = container.get(instance_id) if container else None
624-
if entry is None or not hasattr(entry, "messages"):
625-
return
626-
627-
message_id = assistant_message.get("id")
628-
if message_id and any(
629-
getattr(msg, "id", None) == message_id for msg in getattr(entry, "messages")
630-
):
631-
return
632-
633-
message_model = (
634-
self._message_coercer(assistant_message) # type: ignore[attr-defined]
635-
if getattr(self, "_message_coercer", None)
636-
else self._message_dict_to_message_model(assistant_message)
637-
)
638-
entry.messages.append(message_model) # type: ignore[attr-defined]
639-
if hasattr(entry, "last_message"):
640-
entry.last_message = message_model # type: ignore[attr-defined]
624+
if entry is not None and hasattr(entry, "messages"):
625+
message_id = assistant_message.get("id")
626+
if message_id and any(
627+
getattr(msg, "id", None) == message_id
628+
for msg in getattr(entry, "messages")
629+
):
630+
# Duplicate in state - skip state update but still add to memory
631+
pass
632+
else:
633+
message_model = (
634+
self._message_coercer(assistant_message) # type: ignore[attr-defined]
635+
if getattr(self, "_message_coercer", None)
636+
else self._message_dict_to_message_model(assistant_message)
637+
)
638+
entry.messages.append(message_model) # type: ignore[attr-defined]
639+
if hasattr(entry, "last_message"):
640+
entry.last_message = message_model # type: ignore[attr-defined]
641641

642+
# Always add to memory (required for chat history)
642643
self.memory.add_message(AssistantMessage(**assistant_message))
643644
self.save_state()
644645

dapr_agents/agents/durable.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ def agent_workflow(self, ctx: wf.DaprWorkflowContext, message: dict):
179179
source = metadata.get("source") or "direct"
180180

181181
# Ensure we have the latest durable state for this turn.
182-
self.load_state()
182+
if self.state_store:
183+
self.load_state()
183184

184185
# Bootstrap instance entry (flexible to non-`instances` models).
185186
self.ensure_instance_exists(
@@ -369,6 +370,10 @@ def record_initial_entry(
369370
- start_time: ISO8601 datetime string.
370371
- trace_context: Optional tracing context.
371372
"""
373+
# Load latest state to ensure we have current data before modifying
374+
if self.state_store:
375+
self.load_state()
376+
372377
instance_id = payload.get("instance_id")
373378
trace_context = payload.get("trace_context")
374379
input_value = payload.get("input_value", "Triggered without input.")
@@ -418,10 +423,14 @@ def call_llm(
418423
Raises:
419424
AgentError: If the LLM call fails or yields no message.
420425
"""
426+
# Load latest state to ensure we have current data
427+
if self.state_store:
428+
self.load_state()
429+
421430
instance_id = payload.get("instance_id")
422431
task = payload.get("task")
423432

424-
chat_history = self._construct_messages_with_instance_history(instance_id)
433+
chat_history = self._reconstruct_conversation_history(instance_id)
425434
messages = self.prompting_helper.build_initial_messages(
426435
user_input=task,
427436
chat_history=chat_history,
@@ -481,6 +490,10 @@ def run_tool(
481490
Raises:
482491
AgentError: If tool arguments contain invalid JSON.
483492
"""
493+
# Load latest state to ensure we have current data before modifying
494+
if self.state_store:
495+
self.load_state()
496+
484497
tool_call = payload.get("tool_call", {})
485498
instance_id = payload.get("instance_id")
486499
fn_name = tool_call["function"]["name"]
@@ -548,8 +561,27 @@ async def _execute_tool() -> Any:
548561
if hasattr(entry, "last_message"):
549562
entry.last_message = tool_message_model
550563

551-
# Always persist to memory + in-process tool history
552-
self.memory.add_message(tool_message)
564+
tool_call_id = agent_message["tool_call_id"]
565+
# Check if tool message already exists in memory
566+
existing_memory_messages = self.memory.get_messages()
567+
tool_exists_in_memory = False
568+
for mem_msg in existing_memory_messages:
569+
msg_dict = (
570+
mem_msg.model_dump()
571+
if hasattr(mem_msg, "model_dump")
572+
else (mem_msg if isinstance(mem_msg, dict) else {})
573+
)
574+
if (
575+
msg_dict.get("role") == "tool"
576+
and msg_dict.get("tool_call_id") == tool_call_id
577+
):
578+
tool_exists_in_memory = True
579+
break
580+
581+
# Only add to persistent memory if not already present
582+
if not tool_exists_in_memory:
583+
self.memory.add_message(tool_message)
584+
553585
self.tool_history.append(history_entry)
554586

555587
# Print the tool result for visibility
@@ -647,6 +679,10 @@ def finalize_workflow(
647679
payload: Dict with 'instance_id', 'final_output', 'end_time',
648680
and optional 'triggering_workflow_instance_id'.
649681
"""
682+
# Load latest state to ensure we have current data before modifying
683+
if self.state_store:
684+
self.load_state()
685+
650686
instance_id = payload.get("instance_id")
651687
final_output = payload.get("final_output", "")
652688
end_time = payload.get("end_time", "")

dapr_agents/agents/standalone.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ async def _run_agent(
175175
active_instance = instance_id or self._generate_instance_id()
176176

177177
# Build initial messages with persistent + per-instance history
178-
chat_history = self._construct_messages_with_instance_history(active_instance)
178+
chat_history = self._reconstruct_conversation_history(active_instance)
179179
messages = self.prompting_helper.build_initial_messages(
180180
user_input=input_data,
181181
chat_history=chat_history,
@@ -236,7 +236,7 @@ def construct_messages(
236236
"""
237237
self.load_state()
238238
active_instance = instance_id or self._generate_instance_id()
239-
chat_history = self._construct_messages_with_instance_history(active_instance)
239+
chat_history = self._reconstruct_conversation_history(active_instance)
240240
return self.prompting_helper.build_initial_messages(
241241
user_input=input_data,
242242
chat_history=chat_history,

dapr_agents/memory/daprstatestore.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,23 +85,56 @@ def add_message(self, message: Union[Dict[str, Any], BaseMessage]) -> None:
8585
message (Union[Dict[str, Any], BaseMessage]): The message to add to the memory.
8686
"""
8787
message = self._convert_to_dict(message)
88-
message_id = str(uuid.uuid4())
89-
message_key = self._get_message_key(message_id)
9088
message.update(
9189
{
9290
"createdAt": datetime.now().isoformat() + "Z",
9391
}
9492
)
95-
existing = self.get_messages()
96-
existing.append(message)
97-
logger.debug(
98-
f"Adding message {message} with key {message_key} to session {self.session_id}"
99-
)
100-
self.dapr_store.save_state(
101-
self.session_id,
102-
json.dumps(existing),
103-
state_metadata={"contentType": "application/json"},
104-
)
93+
94+
# Retry loop for optimistic concurrency control
95+
# TODO: make this nicer in future, but for durability this must all be atomic
96+
max_attempts = 10
97+
for attempt in range(1, max_attempts + 1):
98+
try:
99+
response = self.dapr_store.get_state(
100+
self.session_id,
101+
state_metadata={"contentType": "application/json"},
102+
)
103+
104+
if response and response.data:
105+
existing = json.loads(response.data)
106+
etag = response.etag
107+
else:
108+
existing = []
109+
etag = None
110+
111+
existing.append(message)
112+
# Save with etag - will fail if someone else modified it
113+
self.dapr_store.save_state(
114+
self.session_id,
115+
json.dumps(existing),
116+
state_metadata={"contentType": "application/json"},
117+
etag=etag,
118+
)
119+
120+
# Success - exit retry loop
121+
return
122+
123+
except Exception as exc:
124+
if attempt == max_attempts:
125+
logger.exception(
126+
f"Failed to add message to session {self.session_id} after {max_attempts} attempts: {exc}"
127+
)
128+
raise
129+
else:
130+
logger.warning(
131+
f"Conflict adding message to session {self.session_id} (attempt {attempt}/{max_attempts}): {exc}, retrying..."
132+
)
133+
# Brief exponential backoff with jitter
134+
import time
135+
import random
136+
137+
time.sleep(min(0.1 * attempt, 0.5) * (1 + random.uniform(0, 0.25)))
105138

106139
def add_messages(self, messages: List[Union[Dict[str, Any], BaseMessage]]) -> None:
107140
"""

tests/agents/durableagent/test_durable_agent.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,9 @@ async def test_finish_workflow_activity(self, basic_durable_agent):
431431
# Mock the activity context and save_state
432432
mock_ctx = Mock()
433433

434-
with patch.object(basic_durable_agent, "save_state"):
434+
with patch.object(basic_durable_agent, "save_state"), patch.object(
435+
basic_durable_agent, "load_state"
436+
):
435437
basic_durable_agent.finalize_workflow(
436438
mock_ctx,
437439
{
@@ -487,7 +489,9 @@ def test_run_tool(self, basic_durable_agent, mock_tool):
487489
# Mock the activity context and save_state
488490
mock_ctx = Mock()
489491

490-
with patch.object(basic_durable_agent, "save_state"):
492+
with patch.object(basic_durable_agent, "save_state"), patch.object(
493+
basic_durable_agent, "load_state"
494+
):
491495
result = basic_durable_agent.run_tool(
492496
mock_ctx,
493497
{
@@ -741,7 +745,9 @@ def test_create_tool_message_objects(self, basic_durable_agent):
741745

742746
mock_ctx = Mock()
743747

744-
with patch.object(basic_durable_agent, "save_state"):
748+
with patch.object(basic_durable_agent, "save_state"), patch.object(
749+
basic_durable_agent, "load_state"
750+
):
745751
result = basic_durable_agent.run_tool(
746752
mock_ctx,
747753
{
@@ -805,7 +811,9 @@ def test_tool_func(x):
805811
)
806812

807813
# Mock save_state to prevent actual persistence
808-
with patch.object(basic_durable_agent, "save_state"):
814+
with patch.object(basic_durable_agent, "save_state"), patch.object(
815+
basic_durable_agent, "load_state"
816+
):
809817
mock_ctx = Mock()
810818

811819
# Call run_tool activity which appends messages and tool_history
@@ -892,8 +900,8 @@ def test_tool_func(x: str) -> str:
892900
assert basic_durable_agent.tool_history[0].tool_call_id == "call_123"
893901
assert basic_durable_agent.tool_history[0].tool_name == "TestToolFunc"
894902

895-
def test_construct_messages_with_instance_history(self, basic_durable_agent):
896-
"""Test _construct_messages_with_instance_history helper method."""
903+
def test_reconstruct_conversation_history(self, basic_durable_agent):
904+
"""Test test_reconstruct_conversation_history helper method."""
897905
from datetime import datetime, timezone
898906

899907
instance_id = "test-instance-123"
@@ -918,9 +926,7 @@ def test_construct_messages_with_instance_history(self, basic_durable_agent):
918926
start_time=datetime.now(timezone.utc),
919927
)
920928

921-
messages = basic_durable_agent._construct_messages_with_instance_history(
922-
instance_id
923-
)
929+
messages = basic_durable_agent._reconstruct_conversation_history(instance_id)
924930

925931
# Should include messages from instance history (system messages excluded from instance timeline)
926932
# Plus any messages from memory

0 commit comments

Comments
 (0)