diff --git a/backend/chainlit/data/chainlit_data_layer.py b/backend/chainlit/data/chainlit_data_layer.py index 847668b6de..eb194516f7 100644 --- a/backend/chainlit/data/chainlit_data_layer.py +++ b/backend/chainlit/data/chainlit_data_layer.py @@ -568,40 +568,40 @@ async def update_thread( else (metadata.get("name") if metadata and "name" in metadata else None) ) - data = { - "id": thread_id, - "name": thread_name, - "userId": user_id, - "tags": tags, - "metadata": json.dumps(metadata or {}), - "updatedAt": datetime.now(), - } - - # Remove None values - data = {k: v for k, v in data.items() if v is not None} - - # Build the query dynamically based on available fields - columns = [f'"{k}"' for k in data.keys()] - placeholders = [f"${i + 1}" for i in range(len(data))] - values = list(data.values()) - - update_sets = [f'"{k}" = EXCLUDED."{k}"' for k in data.keys() if k != "id"] - - if update_sets: - query = f""" - INSERT INTO "Thread" ({", ".join(columns)}) - VALUES ({", ".join(placeholders)}) - ON CONFLICT (id) DO UPDATE - SET {", ".join(update_sets)}; + now = datetime.now() + + # Convert metadata to JSONB if provided + metadata_jsonb = json.dumps(metadata) if metadata is not None else None + + # Build atomic UPSERT with PostgreSQL-native JSON merging + if metadata_jsonb is not None: + # When metadata is provided, use PostgreSQL's || operator for atomic merging + query = """ + INSERT INTO "Thread" (id, name, "userId", tags, metadata, "updatedAt", "createdAt") + VALUES ($1, $2, $3, $4, $5, $6, $6) + ON CONFLICT (id) DO UPDATE SET + name = COALESCE(EXCLUDED.name, "Thread".name), + "userId" = COALESCE(EXCLUDED."userId", "Thread"."userId"), tags = COALESCE(EXCLUDED.tags, "Thread".tags), + metadata = COALESCE("Thread".metadata::jsonb, '{}'::jsonb) || EXCLUDED.metadata::jsonb, + "updatedAt" = EXCLUDED."updatedAt" """ + params = [thread_id, thread_name, user_id, tags, metadata_jsonb, now] else: - query = f""" - INSERT INTO "Thread" ({", ".join(columns)}) - VALUES ({", ".join(placeholders)}) - ON CONFLICT (id) DO NOTHING + # When 
metadata is None, ensure metadata gets default empty JSON + query = """ + INSERT INTO "Thread" (id, name, "userId", tags, metadata, "updatedAt", "createdAt") + VALUES ($1, $2, $3, $4, '{}'::jsonb, $5, $5) + ON CONFLICT (id) DO UPDATE SET + name = COALESCE(EXCLUDED.name, "Thread".name), + "userId" = COALESCE(EXCLUDED."userId", "Thread"."userId"), tags = COALESCE(EXCLUDED.tags, "Thread".tags), + "updatedAt" = EXCLUDED."updatedAt" """ + params = [thread_id, thread_name, user_id, tags, now] + + # Create params dict with numbered keys for execute_query + numbered_params = {str(i + 1): value for i, value in enumerate(params)} - await self.execute_query(query, {str(i + 1): v for i, v in enumerate(values)}) + await self.execute_query(query, numbered_params) def _extract_feedback_dict_from_step_row(self, row: Dict) -> Optional[FeedbackDict]: if row["feedback_id"] is not None: diff --git a/backend/tests/data/test_chainlit_data_layer.py b/backend/tests/data/test_chainlit_data_layer.py new file mode 100644 index 0000000000..357fa70e23 --- /dev/null +++ b/backend/tests/data/test_chainlit_data_layer.py @@ -0,0 +1,219 @@ +"""Tests for ChainlitDataLayer metadata handling.""" + +import json +from unittest.mock import AsyncMock, patch + +import pytest + +from chainlit.data.chainlit_data_layer import ChainlitDataLayer + + +@pytest.fixture +def mock_pool(): + """Mock asyncpg connection pool.""" + return AsyncMock() + + +@pytest.fixture +def data_layer(mock_pool): + """Create ChainlitDataLayer instance with mocked pool.""" + layer = ChainlitDataLayer("postgresql://test", show_logger=False) + layer.pool = mock_pool + return layer + + +@pytest.mark.asyncio +async def test_update_thread_metadata_none_preserves_existing(data_layer): + """Test that passing metadata=None doesn't overwrite existing metadata.""" + thread_id = "test_thread_123" + existing_metadata = {"key1": "value1", "key2": "value2"} + + # Mock get_thread to return existing thread with metadata + mock_thread = {"id": thread_id, "metadata": json.dumps(existing_metadata)} + + with 
patch.object(data_layer, "get_thread", return_value=mock_thread): + with patch.object(data_layer, "execute_query", return_value=[]) as mock_execute: + # Call update_thread with metadata=None (default) + await data_layer.update_thread(thread_id=thread_id, name="Test Thread") + + # Verify execute_query was called + assert mock_execute.called + + # Get the query and parameters that were passed + call_args = mock_execute.call_args + query = call_args[0][0] + + # The ON CONFLICT clause should NOT overwrite the metadata column when metadata=None + assert "metadata =" not in query + + +@pytest.mark.asyncio +async def test_update_thread_metadata_merges_with_existing(data_layer): + """Test that new metadata is sent to PostgreSQL for atomic merging.""" + thread_id = "test_thread_456" + new_metadata = {"key1": "new_value", "key3": "added"} + + # Capture the actual SQL parameters that would be passed + captured_params = {} + captured_query = "" + + def capture_execute_query(query: str, params=None): + nonlocal captured_params, captured_query + captured_params = params or {} + captured_query = query + return [] + + with patch.object( + data_layer, "execute_query", side_effect=capture_execute_query + ) as mock_execute: + # Call update_thread with new metadata + await data_layer.update_thread(thread_id=thread_id, metadata=new_metadata) + + # Verify execute_query was called + assert mock_execute.called + + # Verify the query uses PostgreSQL's atomic JSON merge operator + assert "|| EXCLUDED.metadata::jsonb" in captured_query + + # Verify only the new metadata is sent (not pre-merged) + metadata_json = None + for key, value in captured_params.items(): + if isinstance(value, str) and (value.startswith("{") or value == "{}"): + metadata_json = value + break + + assert metadata_json is not None, ( + f"No JSON metadata found in params: {captured_params}" + ) + metadata_param = json.loads(metadata_json) + + # Verify only the new metadata is sent to PostgreSQL (merging happens DB-side) + assert 
metadata_param == new_metadata + + +@pytest.mark.asyncio +async def test_update_thread_metadata_empty_dict_updates(data_layer): + """Test that passing empty dict {} gets sent to PostgreSQL for atomic merging.""" + thread_id = "test_thread_789" + new_metadata = {} + + # Capture the actual SQL parameters that would be passed + captured_params = {} + captured_query = "" + + def capture_execute_query(query: str, params=None): + nonlocal captured_params, captured_query + captured_params = params or {} + captured_query = query + return [] + + with patch.object( + data_layer, "execute_query", side_effect=capture_execute_query + ) as mock_execute: + # Call update_thread with empty metadata dict + await data_layer.update_thread(thread_id=thread_id, metadata=new_metadata) + + # Verify execute_query was called + assert mock_execute.called + + # Verify the query uses PostgreSQL's atomic JSON merge operator + assert "|| EXCLUDED.metadata::jsonb" in captured_query + + # Find the metadata parameter + metadata_json = None + for key, value in captured_params.items(): + if isinstance(value, str) and (value.startswith("{") or value == "{}"): + metadata_json = value + break + + assert metadata_json is not None, ( + f"No JSON metadata found in params: {captured_params}" + ) + metadata_param = json.loads(metadata_json) + + # Verify the empty dict is sent to PostgreSQL (which will merge with existing) + assert metadata_param == new_metadata + + +@pytest.mark.asyncio +async def test_update_thread_metadata_no_existing_thread(data_layer): + """Test that metadata works correctly when no existing thread.""" + thread_id = "test_thread_new" + new_metadata = {"key1": "value1", "key2": "value2"} + + # Capture the actual SQL parameters that would be passed + captured_params = {} + + def capture_execute_query(query: str, params=None): + nonlocal captured_params + captured_params = params or {} + return [] + + # Mock get_thread to return None (no existing thread) + with patch.object(data_layer, 
"get_thread", return_value=None): + with patch.object( + data_layer, "execute_query", side_effect=capture_execute_query + ) as mock_execute: + # Call update_thread with new metadata + await data_layer.update_thread(thread_id=thread_id, metadata=new_metadata) + + # Verify execute_query was called + assert mock_execute.called + + # Find the metadata parameter + metadata_json = None + for key, value in captured_params.items(): + if isinstance(value, str) and (value.startswith("{") or value == "{}"): + metadata_json = value + break + + assert metadata_json is not None, ( + f"No JSON metadata found in params: {captured_params}" + ) + metadata_param = json.loads(metadata_json) + + # Verify the new metadata was stored as-is + assert metadata_param == new_metadata + + +@pytest.mark.asyncio +async def test_update_thread_metadata_corrupted_existing_fallback(data_layer): + """Test fallback behavior when existing metadata is corrupted.""" + thread_id = "test_thread_corrupted" + new_metadata = {"key1": "value1"} + + # Mock get_thread to return thread with corrupted metadata + mock_thread = {"id": thread_id, "metadata": "invalid_json_string"} + + # Capture the actual SQL parameters that would be passed + captured_params = {} + + def capture_execute_query(query: str, params=None): + nonlocal captured_params + captured_params = params or {} + return [] + + with patch.object(data_layer, "get_thread", return_value=mock_thread): + with patch.object( + data_layer, "execute_query", side_effect=capture_execute_query + ) as mock_execute: + # Call update_thread with new metadata + await data_layer.update_thread(thread_id=thread_id, metadata=new_metadata) + + # Verify execute_query was called + assert mock_execute.called + + # Find the metadata parameter + metadata_json = None + for key, value in captured_params.items(): + if isinstance(value, str) and (value.startswith("{") or value == "{}"): + metadata_json = value + break + + assert metadata_json is not None, ( + f"No JSON metadata 
found in params: {captured_params}" + ) + metadata_param = json.loads(metadata_json) + + # Should fallback to just the new metadata + assert metadata_param == new_metadata