From f46230ba00fc52490cdb263ce85cef32b88ef9df Mon Sep 17 00:00:00 2001 From: Rauf Akdemir Date: Mon, 29 Dec 2025 21:06:39 +0100 Subject: [PATCH 1/9] feat(sync): add multiplexer module for destination migrations - Add SyncMultiplexer for managing multiple destinations per sync - Implement fork/switch/resync operations for blue-green deployments - Add ARFReplaySource for replaying entities from raw data store - Refactor SyncFactory into modular builders (_source, _destination, _context, _pipeline) - Add DestinationRole enum (ACTIVE, SHADOW, DEPRECATED) to SyncConnection - Add feature flag SYNC_MULTIPLEXER for gating access - Add CRUD layer for SyncConnection with role-based filtering - Add API endpoints for multiplex operations --- backend/airweave/api/v1/api.py | 2 + .../api/v1/endpoints/sync_multiplex.py | 206 +++++ backend/airweave/core/shared_models.py | 1 + backend/airweave/crud/__init__.py | 2 + backend/airweave/crud/crud_sync_connection.py | 285 ++++++ backend/airweave/models/__init__.py | 3 +- backend/airweave/models/sync_connection.py | 25 +- backend/airweave/platform/sync/__init__.py | 8 +- backend/airweave/platform/sync/factory.py | 829 ------------------ .../platform/sync/factory/__init__.py | 17 + .../platform/sync/factory/_context.py | 243 +++++ .../platform/sync/factory/_destination.py | 203 +++++ .../platform/sync/factory/_factory.py | 180 ++++ .../platform/sync/factory/_pipeline.py | 119 +++ .../airweave/platform/sync/factory/_source.py | 287 ++++++ .../platform/sync/multiplex/__init__.py | 29 + .../platform/sync/multiplex/multiplexer.py | 403 +++++++++ .../platform/sync/multiplex/replay.py | 246 ++++++ backend/airweave/schemas/__init__.py | 16 + backend/airweave/schemas/sync_connection.py | 89 ++ .../versions/add_role_to_sync_connection.py | 38 + 21 files changed, 2398 insertions(+), 833 deletions(-) create mode 100644 backend/airweave/api/v1/endpoints/sync_multiplex.py create mode 100644 backend/airweave/crud/crud_sync_connection.py delete mode 
100644 backend/airweave/platform/sync/factory.py create mode 100644 backend/airweave/platform/sync/factory/__init__.py create mode 100644 backend/airweave/platform/sync/factory/_context.py create mode 100644 backend/airweave/platform/sync/factory/_destination.py create mode 100644 backend/airweave/platform/sync/factory/_factory.py create mode 100644 backend/airweave/platform/sync/factory/_pipeline.py create mode 100644 backend/airweave/platform/sync/factory/_source.py create mode 100644 backend/airweave/platform/sync/multiplex/__init__.py create mode 100644 backend/airweave/platform/sync/multiplex/multiplexer.py create mode 100644 backend/airweave/platform/sync/multiplex/replay.py create mode 100644 backend/airweave/schemas/sync_connection.py create mode 100644 backend/alembic/versions/add_role_to_sync_connection.py diff --git a/backend/airweave/api/v1/api.py b/backend/airweave/api/v1/api.py index afbdb243c..ccd94e495 100644 --- a/backend/airweave/api/v1/api.py +++ b/backend/airweave/api/v1/api.py @@ -22,6 +22,7 @@ source_rate_limits, sources, sync, + sync_multiplex, transformers, usage, users, @@ -52,6 +53,7 @@ source_rate_limits.router, prefix="/source-rate-limits", tags=["source-rate-limits"] ) api_router.include_router(sync.router, prefix="/sync", tags=["sync"]) +api_router.include_router(sync_multiplex.router, prefix="/sync", tags=["sync-multiplex"]) api_router.include_router(entities.router, prefix="/entities", tags=["entities"]) api_router.include_router(entity_counts.router, prefix="/entity-counts", tags=["entity-counts"]) api_router.include_router(transformers.router, prefix="/transformers", tags=["transformers"]) diff --git a/backend/airweave/api/v1/endpoints/sync_multiplex.py b/backend/airweave/api/v1/endpoints/sync_multiplex.py new file mode 100644 index 000000000..fdd7d11f5 --- /dev/null +++ b/backend/airweave/api/v1/endpoints/sync_multiplex.py @@ -0,0 +1,206 @@ +"""Endpoints for sync multiplexing (destination migrations). 
+ +Enables blue-green deployments for vector DB migrations: +- Fork: Add shadow destination + optionally replay from ARF +- Switch: Promote shadow to active +- List: Show all destinations with roles +- Resync: Force full sync from source to refresh ARF + +Feature-gated: Requires SYNC_MULTIPLEXER feature flag enabled for the organization. +""" + +from typing import List +from uuid import UUID + +from fastapi import Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession + +from airweave import schemas +from airweave.api import deps +from airweave.api.context import ApiContext +from airweave.api.router import TrailingSlashRouter +from airweave.core.shared_models import FeatureFlag +from airweave.platform.sync.multiplex.multiplexer import SyncMultiplexer + +router = TrailingSlashRouter() + + +def _require_multiplexer_feature(ctx: ApiContext) -> None: + """Check if organization has multiplexer feature enabled. + + Args: + ctx: API context + + Raises: + HTTPException: If feature not enabled + """ + if not ctx.has_feature(FeatureFlag.SYNC_MULTIPLEXER): + raise HTTPException( + status_code=403, + detail="Sync multiplexer feature is not enabled for this organization", + ) + + +@router.get( + "/{sync_id}/destinations", + response_model=List[schemas.DestinationSlotInfo], + summary="List destination slots", + description="List all destinations for a sync with their roles (active/shadow/deprecated).", +) +async def list_destinations( + sync_id: UUID, + db: AsyncSession = Depends(deps.get_db), + ctx: ApiContext = Depends(deps.get_context), +) -> List[schemas.DestinationSlotInfo]: + """List all destination slots for a sync. + + Returns slots sorted by role: ACTIVE first, then SHADOW, then DEPRECATED. 
+ """ + _require_multiplexer_feature(ctx) + multiplexer = SyncMultiplexer(db, ctx, ctx.logger) + return await multiplexer.list_destinations(sync_id) + + +@router.post( + "/{sync_id}/destinations/fork", + response_model=schemas.ForkDestinationResponse, + summary="Fork a new destination", + description="Add a shadow destination for migration testing. Optionally replay from ARF store.", +) +async def fork_destination( + sync_id: UUID, + request: schemas.ForkDestinationRequest, + db: AsyncSession = Depends(deps.get_db), + ctx: ApiContext = Depends(deps.get_context), +) -> schemas.ForkDestinationResponse: + """Fork a new shadow destination. + + Creates a new destination slot with SHADOW role. If replay_from_arf is True, + entities will be replayed from the ARF store to populate the new destination. + + Args: + sync_id: Sync ID to fork destination for + request: Fork request with destination connection ID and replay flag + db: Database session + ctx: API context + + Returns: + ForkDestinationResponse with slot and optional replay job info + """ + _require_multiplexer_feature(ctx) + multiplexer = SyncMultiplexer(db, ctx, ctx.logger) + slot, replay_job = await multiplexer.fork( + sync_id=sync_id, + destination_connection_id=request.destination_connection_id, + replay_from_arf=request.replay_from_arf, + ) + + slot_schema = schemas.SyncConnectionSchema( + id=slot.id, + sync_id=slot.sync_id, + connection_id=slot.connection_id, + role=slot.role, + created_at=slot.created_at, + modified_at=slot.modified_at, + ) + + return schemas.ForkDestinationResponse( + slot=slot_schema, + replay_job_id=replay_job.id if replay_job else None, + replay_job_status=replay_job.status.value if replay_job else None, + ) + + +@router.post( + "/{sync_id}/destinations/{slot_id}/switch", + response_model=schemas.SwitchDestinationResponse, + summary="Switch active destination", + description="Promote a shadow destination to active. 
The current active becomes deprecated.", +) +async def switch_destination( + sync_id: UUID, + slot_id: UUID, + db: AsyncSession = Depends(deps.get_db), + ctx: ApiContext = Depends(deps.get_context), +) -> schemas.SwitchDestinationResponse: + """Switch the active destination. + + Promotes the specified shadow slot to ACTIVE and demotes the current + ACTIVE slot to DEPRECATED. + + Args: + sync_id: Sync ID + slot_id: Slot ID to promote to active + db: Database session + ctx: API context + + Returns: + Switch response with new and previous active slot IDs + """ + _require_multiplexer_feature(ctx) + multiplexer = SyncMultiplexer(db, ctx, ctx.logger) + return await multiplexer.switch(sync_id=sync_id, new_active_slot_id=slot_id) + + +@router.post( + "/{sync_id}/resync", + response_model=schemas.SyncJob, + summary="Resync from source", + description="Force a full sync from the source to refresh the ARF store.", +) +async def resync_from_source( + sync_id: UUID, + db: AsyncSession = Depends(deps.get_db), + ctx: ApiContext = Depends(deps.get_context), +) -> schemas.SyncJob: + """Force full sync from source to refresh ARF. + + Triggers a full sync (ignoring cursor) to ensure the ARF store is up-to-date + before forking to a new destination. + + Args: + sync_id: Sync ID + db: Database session + ctx: API context + + Returns: + SyncJob for tracking progress + """ + _require_multiplexer_feature(ctx) + multiplexer = SyncMultiplexer(db, ctx, ctx.logger) + return await multiplexer.resync_from_source(sync_id=sync_id) + + +@router.get( + "/{sync_id}/destinations/active", + response_model=schemas.DestinationSlotInfo, + summary="Get active destination", + description="Get the currently active destination for a sync.", +) +async def get_active_destination( + sync_id: UUID, + db: AsyncSession = Depends(deps.get_db), + ctx: ApiContext = Depends(deps.get_context), +) -> schemas.DestinationSlotInfo: + """Get the active destination slot. 
+ + Args: + sync_id: Sync ID + db: Database session + ctx: API context + + Returns: + Active destination info + + Raises: + HTTPException: If no active destination found + """ + _require_multiplexer_feature(ctx) + multiplexer = SyncMultiplexer(db, ctx, ctx.logger) + active = await multiplexer.get_active_destination(sync_id) + if not active: + raise HTTPException( + status_code=404, + detail=f"No active destination found for sync {sync_id}", + ) + return active diff --git a/backend/airweave/core/shared_models.py b/backend/airweave/core/shared_models.py index 2fbdfaf23..d030d8925 100644 --- a/backend/airweave/core/shared_models.py +++ b/backend/airweave/core/shared_models.py @@ -102,6 +102,7 @@ class FeatureFlag(str, Enum): PRIORITY_SUPPORT = "priority_support" SOURCE_RATE_LIMITING = "source_rate_limiting" ZEPHYR_SCALE = "zephyr_scale" # Enables Zephyr Scale test management sync for Jira + SYNC_MULTIPLEXER = "sync_multiplexer" # Destination multiplexing for migrations class AuthMethod(str, Enum): diff --git a/backend/airweave/crud/__init__.py b/backend/airweave/crud/__init__.py index a9feedb65..7c608b39f 100644 --- a/backend/airweave/crud/__init__.py +++ b/backend/airweave/crud/__init__.py @@ -20,6 +20,7 @@ from .crud_source_connection import source_connection from .crud_source_rate_limit import source_rate_limit from .crud_sync import sync +from .crud_sync_connection import sync_connection from .crud_sync_cursor import sync_cursor from .crud_sync_job import sync_job from .crud_transformer import transformer @@ -48,6 +49,7 @@ "source_connection", "source_rate_limit", "sync", + "sync_connection", "sync_cursor", "sync_job", "transformer", diff --git a/backend/airweave/crud/crud_sync_connection.py b/backend/airweave/crud/crud_sync_connection.py new file mode 100644 index 000000000..067d39328 --- /dev/null +++ b/backend/airweave/crud/crud_sync_connection.py @@ -0,0 +1,285 @@ +"""CRUD operations for sync connections. 
+ +Provides methods for managing destination slots in the multiplexer. +""" + +from typing import List, Optional +from uuid import UUID + +from sqlalchemy import and_, select, update +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from airweave.db.unit_of_work import UnitOfWork +from airweave.models.sync_connection import DestinationRole, SyncConnection + + +class CRUDSyncConnection: + """CRUD operations for sync connections. + + Note: SyncConnection doesn't have organization_id directly. + Access control should be enforced at the Sync level before calling these methods. + """ + + def __init__(self): + """Initialize the CRUD object.""" + self.model = SyncConnection + + async def get( + self, + db: AsyncSession, + id: UUID, + ) -> Optional[SyncConnection]: + """Get sync connection by ID. + + Args: + db: Database session + id: Sync connection ID + + Returns: + SyncConnection if found, None otherwise + """ + result = await db.execute(select(self.model).where(self.model.id == id)) + return result.scalar_one_or_none() + + async def get_by_sync_id( + self, + db: AsyncSession, + sync_id: UUID, + ) -> List[SyncConnection]: + """Get all sync connections for a sync. + + Args: + db: Database session + sync_id: Sync ID + + Returns: + List of sync connections + """ + result = await db.execute( + select(self.model) + .where(self.model.sync_id == sync_id) + .options(selectinload(self.model.connection)) + ) + return list(result.scalars().all()) + + async def get_by_sync_and_connection( + self, + db: AsyncSession, + sync_id: UUID, + connection_id: UUID, + ) -> Optional[SyncConnection]: + """Get sync connection by sync ID and connection ID. 
+ + Args: + db: Database session + sync_id: Sync ID + connection_id: Connection ID + + Returns: + SyncConnection if found, None otherwise + """ + result = await db.execute( + select(self.model).where( + and_( + self.model.sync_id == sync_id, + self.model.connection_id == connection_id, + ) + ) + ) + return result.scalar_one_or_none() + + async def get_by_sync_and_role( + self, + db: AsyncSession, + sync_id: UUID, + role: DestinationRole, + ) -> List[SyncConnection]: + """Get sync connections by sync ID and role. + + Args: + db: Database session + sync_id: Sync ID + role: Destination role (active, shadow, deprecated) + + Returns: + List of sync connections with the specified role + """ + result = await db.execute( + select(self.model) + .where( + and_( + self.model.sync_id == sync_id, + self.model.role == role.value, + ) + ) + .options(selectinload(self.model.connection)) + ) + return list(result.scalars().all()) + + async def get_active_and_shadow( + self, + db: AsyncSession, + sync_id: UUID, + ) -> List[SyncConnection]: + """Get active and shadow sync connections for a sync. + + Used during sync to get all destinations that should receive writes. + + Args: + db: Database session + sync_id: Sync ID + + Returns: + List of active and shadow sync connections + """ + result = await db.execute( + select(self.model) + .where( + and_( + self.model.sync_id == sync_id, + self.model.role.in_( + [DestinationRole.ACTIVE.value, DestinationRole.SHADOW.value] + ), + ) + ) + .options(selectinload(self.model.connection)) + ) + return list(result.scalars().all()) + + async def create( + self, + db: AsyncSession, + *, + sync_id: UUID, + connection_id: UUID, + role: DestinationRole = DestinationRole.ACTIVE, + uow: Optional[UnitOfWork] = None, + ) -> SyncConnection: + """Create a new sync connection. 
+ + Args: + db: Database session + sync_id: Sync ID + connection_id: Connection ID + role: Destination role (default: active) + uow: Optional unit of work for transaction control + + Returns: + Created sync connection + """ + db_obj = SyncConnection( + sync_id=sync_id, + connection_id=connection_id, + role=role.value, + ) + db.add(db_obj) + + if uow: + await uow.flush() + else: + await db.commit() + await db.refresh(db_obj) + + return db_obj + + async def update_role( + self, + db: AsyncSession, + *, + id: UUID, + role: DestinationRole, + uow: Optional[UnitOfWork] = None, + ) -> Optional[SyncConnection]: + """Update the role of a sync connection. + + Args: + db: Database session + id: Sync connection ID + role: New role + uow: Optional unit of work for transaction control + + Returns: + Updated sync connection + """ + await db.execute(update(self.model).where(self.model.id == id).values(role=role.value)) + + if uow: + await uow.flush() + else: + await db.commit() + + return await self.get(db, id=id) + + async def bulk_update_role( + self, + db: AsyncSession, + *, + sync_id: UUID, + from_role: DestinationRole, + to_role: DestinationRole, + uow: Optional[UnitOfWork] = None, + ) -> int: + """Bulk update roles for all sync connections with a specific role. + + Args: + db: Database session + sync_id: Sync ID + from_role: Current role to match + to_role: New role to set + uow: Optional unit of work for transaction control + + Returns: + Number of rows updated + """ + result = await db.execute( + update(self.model) + .where( + and_( + self.model.sync_id == sync_id, + self.model.role == from_role.value, + ) + ) + .values(role=to_role.value) + ) + + if uow: + await uow.flush() + else: + await db.commit() + + return result.rowcount + + async def remove( + self, + db: AsyncSession, + *, + id: UUID, + uow: Optional[UnitOfWork] = None, + ) -> bool: + """Remove a sync connection. 
+ + Args: + db: Database session + id: Sync connection ID + uow: Optional unit of work for transaction control + + Returns: + True if deleted, False if not found + """ + db_obj = await self.get(db, id=id) + if not db_obj: + return False + + await db.delete(db_obj) + + if uow: + await uow.flush() + else: + await db.commit() + + return True + + +# Singleton instance +sync_connection = CRUDSyncConnection() diff --git a/backend/airweave/models/__init__.py b/backend/airweave/models/__init__.py index 2ade32103..2c423ce94 100644 --- a/backend/airweave/models/__init__.py +++ b/backend/airweave/models/__init__.py @@ -23,7 +23,7 @@ from .source_connection import SourceConnection from .source_rate_limit import SourceRateLimit from .sync import Sync -from .sync_connection import SyncConnection +from .sync_connection import DestinationRole, SyncConnection from .sync_cursor import SyncCursor from .sync_job import SyncJob from .transformer import Transformer @@ -56,6 +56,7 @@ "SourceConnection", "SourceRateLimit", "Sync", + "DestinationRole", "SyncConnection", "SyncCursor", "SyncJob", diff --git a/backend/airweave/models/sync_connection.py b/backend/airweave/models/sync_connection.py index f7ddbff3a..37c06400b 100644 --- a/backend/airweave/models/sync_connection.py +++ b/backend/airweave/models/sync_connection.py @@ -1,9 +1,10 @@ """Sync connection model.""" +from enum import Enum from typing import TYPE_CHECKING from uuid import UUID -from sqlalchemy import ForeignKey +from sqlalchemy import ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column, relationship from airweave.models._base import Base @@ -13,8 +14,23 @@ from airweave.models.sync import Sync +class DestinationRole(str, Enum): + """Role of a destination in a sync for multiplexing support. + + Used to enable blue-green deployments and migrations between vector DB configs. 
+ """ + + ACTIVE = "active" # Receives writes + serves queries + SHADOW = "shadow" # Receives writes only (for migration testing) + DEPRECATED = "deprecated" # No longer in use (kept for rollback) + + class SyncConnection(Base): - """Sync connection model.""" + """Sync connection model. + + Links syncs to their source and destination connections. + For destinations, the `role` field enables multiplexing for migrations. + """ __tablename__ = "sync_connection" @@ -22,6 +38,11 @@ class SyncConnection(Base): connection_id: Mapped[UUID] = mapped_column( ForeignKey("connection.id", ondelete="CASCADE"), nullable=False ) + # Role for destination connections (active/shadow/deprecated) + # Used for blue-green deployments and migrations + role: Mapped[str] = mapped_column( + String(20), default=DestinationRole.ACTIVE.value, nullable=False + ) # Add relationship back to Sync sync: Mapped["Sync"] = relationship("Sync", back_populates="sync_connections") diff --git a/backend/airweave/platform/sync/__init__.py b/backend/airweave/platform/sync/__init__.py index 1bc1ea224..97ec43c30 100644 --- a/backend/airweave/platform/sync/__init__.py +++ b/backend/airweave/platform/sync/__init__.py @@ -1,13 +1,19 @@ """Sync module for Airweave. 
Provides: +- SyncFactory: Creates orchestrators (from sync.factory import SyncFactory) - SyncOrchestrator: Coordinates the entire sync workflow - EntityPipeline: Processes entities through transformation stages - SyncContext: Immutable container for sync resources - RawDataService: Stores raw entities with entity-level granularity + +Multiplexing (import from sync.multiplex directly): +- SyncMultiplexer: Manages multiple destinations per sync (migrations) +- ARFReplaySource: Pseudo-source for replaying from ARF +- replay_to_destination: Replays entities from ARF to destinations """ -from .raw_data import ( +from airweave.platform.sync.raw_data import ( RawDataService, SyncManifest, raw_data_service, diff --git a/backend/airweave/platform/sync/factory.py b/backend/airweave/platform/sync/factory.py deleted file mode 100644 index 804c7748c..000000000 --- a/backend/airweave/platform/sync/factory.py +++ /dev/null @@ -1,829 +0,0 @@ -"""Module for sync factory that creates context and orchestrator instances.""" - -import importlib -import time -from typing import Any, Optional -from uuid import UUID - -from sqlalchemy.ext.asyncio import AsyncSession - -from airweave import crud, schemas -from airweave.api.context import ApiContext -from airweave.core import credentials -from airweave.core.config import settings -from airweave.core.constants.reserved_ids import NATIVE_QDRANT_UUID, RESERVED_TABLE_ENTITY_ID -from airweave.core.exceptions import NotFoundException -from airweave.core.guard_rail_service import GuardRailService -from airweave.core.logging import ContextualLogger, LoggerConfigurator, logger -from airweave.core.sync_cursor_service import sync_cursor_service -from airweave.db.init_db_native import init_db_with_entity_definitions -from airweave.platform.auth_providers._base import BaseAuthProvider -from airweave.platform.destinations._base import BaseDestination, ProcessingRequirement -from airweave.platform.entities._base import BaseEntity -from 
airweave.platform.locator import resource_locator -from airweave.platform.sources._base import BaseSource -from airweave.platform.sync.actions import ActionDispatcher, ActionResolver -from airweave.platform.sync.context import SyncContext -from airweave.platform.sync.cursor import SyncCursor -from airweave.platform.sync.entity_pipeline import EntityPipeline -from airweave.platform.sync.handlers import ( - PostgresMetadataHandler, - RawDataHandler, - VectorDBHandler, -) -from airweave.platform.sync.orchestrator import SyncOrchestrator -from airweave.platform.sync.pipeline.entity_tracker import EntityTracker -from airweave.platform.sync.state_publisher import SyncStatePublisher -from airweave.platform.sync.stream import AsyncSourceStream -from airweave.platform.sync.token_manager import TokenManager -from airweave.platform.sync.worker_pool import AsyncWorkerPool -from airweave.platform.utils.source_factory_utils import ( - get_auth_configuration, - process_credentials_for_source, -) - - -class SyncFactory: - """Factory for sync orchestrator.""" - - @classmethod - async def create_orchestrator( - cls, - db: AsyncSession, - sync: schemas.Sync, - sync_job: schemas.SyncJob, - collection: schemas.Collection, - connection: schemas.Connection, # Passed but unused - we load from DB - ctx: ApiContext, - access_token: Optional[str] = None, - max_workers: int = None, - force_full_sync: bool = False, - ) -> SyncOrchestrator: - """Create a dedicated orchestrator instance for a sync run. - - This method creates all necessary components for a sync run, including the - context and a dedicated orchestrator instance for concurrent execution. 
- - Args: - db: Database session - sync: The sync configuration - sync_job: The sync job - collection: The collection to sync to - connection: The connection (unused - we load source connection from DB) - ctx: The API context - access_token: Optional token to use instead of stored credentials - max_workers: Maximum number of concurrent workers (default: from settings) - force_full_sync: If True, forces a full sync with orphaned entity deletion - - Returns: - A dedicated SyncOrchestrator instance - """ - # Use configured value if max_workers not specified - if max_workers is None: - max_workers = settings.SYNC_MAX_WORKERS - logger.debug(f"Using configured max_workers: {max_workers}") - - # Track initialization timing - init_start = time.time() - - # Create sync context - logger.info("Creating sync context...") - context_start = time.time() - sync_context = await cls._create_sync_context( - db=db, - sync=sync, - sync_job=sync_job, - collection=collection, - connection=connection, # Unused parameter - ctx=ctx, - access_token=access_token, - force_full_sync=force_full_sync, - ) - logger.debug(f"Sync context created in {time.time() - context_start:.2f}s") - - # Create pipeline components - logger.debug("Initializing pipeline components...") - - # 1. Action Resolver - action_resolver = ActionResolver(entity_map=sync_context.entity_map) - - # 2. Handlers - grouped by destination processing requirements - handlers = cls._create_destination_handlers(sync_context) - handlers.append(RawDataHandler()) # Raw data storage - handlers.append(PostgresMetadataHandler()) # Metadata (runs last) - - # 3. Action Dispatcher - action_dispatcher = ActionDispatcher(handlers=handlers) - - # 4. 
Entity Pipeline - entity_pipeline = EntityPipeline( - entity_tracker=sync_context.entity_tracker, - action_resolver=action_resolver, - action_dispatcher=action_dispatcher, - ) - - # Create worker pool - worker_pool = AsyncWorkerPool(max_workers=max_workers, logger=sync_context.logger) - - stream = AsyncSourceStream( - source_generator=sync_context.source.generate_entities(), - queue_size=10000, # TODO: make this configurable - logger=sync_context.logger, - ) - - # Create dedicated orchestrator instance with all components - orchestrator = SyncOrchestrator( - entity_pipeline=entity_pipeline, - worker_pool=worker_pool, - stream=stream, - sync_context=sync_context, - ) - - logger.info(f"Total orchestrator initialization took {time.time() - init_start:.2f}s") - - return orchestrator - - @classmethod - async def _create_sync_context( - cls, - db: AsyncSession, - sync: schemas.Sync, - sync_job: schemas.SyncJob, - collection: schemas.Collection, - connection: schemas.Connection, - ctx: ApiContext, - access_token: Optional[str] = None, - force_full_sync: bool = False, - ) -> SyncContext: - """Create a sync context. 
- - Args: - db: Database session - sync: The sync configuration - sync_job: The sync job - collection: The collection to sync to - connection: The connection (unused - we load source connection from DB) - ctx: The API context - access_token: Optional token to use instead of stored credentials - force_full_sync: If True, forces a full sync with orphaned entity deletion - - Returns: - SyncContext object with all required components - """ - # Get source connection data first (includes source class with cursor schema) - source_connection_data = await cls._get_source_connection_data(db, sync, ctx) - - # Create a contextualized logger with all job metadata - logger = LoggerConfigurator.configure_logger( - "airweave.platform.sync", - dimensions={ - "sync_id": str(sync.id), - "sync_job_id": str(sync_job.id), - "organization_id": str(ctx.organization.id), - "source_connection_id": str(source_connection_data["connection_id"]), - "collection_readable_id": str(collection.readable_id), - "organization_name": ctx.organization.name, - "scheduled": str(sync_job.scheduled), - }, - ) - - source = await cls._create_source_instance_with_data( - db=db, - source_connection_data=source_connection_data, - ctx=ctx, - access_token=access_token, - logger=logger, # Pass the contextual logger - sync_job=sync_job, # Pass sync_job for file downloader temp directory setup - ) - destinations = await cls._create_destination_instances( - db=db, - sync=sync, - collection=collection, - ctx=ctx, - logger=logger, - ) - entity_map = await cls._get_entity_definition_map(db=db) - - # NEW: Load initial entity counts from database for state tracking - initial_counts = await crud.entity_count.get_counts_per_sync_and_type(db, sync.id) - - logger.info(f"🔢 Loaded initial entity counts: {len(initial_counts)} entity types") - - # Log the initial counts for debugging - for count in initial_counts: - logger.debug(f" - {count.entity_definition_name}: {count.count} entities") - - # Create EntityTracker (pure state 
tracking) - entity_tracker = EntityTracker( - job_id=sync_job.id, - sync_id=sync.id, - logger=logger, - initial_counts=initial_counts, - ) - - # Create SyncStatePublisher (handles pubsub publishing) - state_publisher = SyncStatePublisher( - job_id=sync_job.id, - sync_id=sync.id, - entity_tracker=entity_tracker, - logger=logger, - ) - - logger.info(f"✅ Created EntityTracker and SyncStatePublisher for job {sync_job.id}") - - logger.info("Sync context created") - - # Create GuardRailService with contextual logger - guard_rail = GuardRailService( - organization_id=ctx.organization.id, - logger=logger.with_context(component="guardrail"), - ) - - # Load existing cursor data from database - # IMPORTANT: When force_full_sync is True (daily cleanup), we intentionally - # skip loading cursor DATA to force a full sync. - # This ensures we see ALL entities in the source, not just changed ones, - # for accurate orphaned entity detection. We still track and save cursor - # values during the sync for the next incremental sync. - - # Get cursor schema from source class (direct reference, no string lookup!) - cursor_schema = None - source_class = source_connection_data["source_class"] - if hasattr(source_class, "_cursor_class") and source_class._cursor_class: - cursor_schema = source_class._cursor_class # Direct class reference - logger.debug(f"Source has typed cursor: {cursor_schema.__name__}") - - if force_full_sync: - logger.info( - "🔄 FORCE FULL SYNC: Skipping cursor data to ensure all entities are fetched " - "for accurate orphaned entity cleanup. Will still track cursor for next sync." 
- ) - cursor_data = None # Force full sync by not providing previous cursor data - else: - # Normal incremental sync - load cursor data - cursor_data = await sync_cursor_service.get_cursor_data(db=db, sync_id=sync.id, ctx=ctx) - if cursor_data: - logger.info(f"📊 Incremental sync: Using cursor data with {len(cursor_data)} keys") - - # Create typed cursor (no locator needed - direct class reference!) - cursor = SyncCursor( - sync_id=sync.id, - cursor_schema=cursor_schema, - cursor_data=cursor_data, - ) - - # Precompute destination keyword-index capability once - has_keyword_index = False - try: - import asyncio as _asyncio - - if destinations: - has_keyword_index = any( - await _asyncio.gather(*[dest.has_keyword_index() for dest in destinations]) - ) - except Exception as _e: - logger.warning(f"Failed to precompute keyword index capability on destinations: {_e}") - has_keyword_index = False - - # Create sync context - sync_context = SyncContext( - source=source, - destinations=destinations, - sync=sync, - sync_job=sync_job, - collection=collection, - connection=connection, # Unused parameter - entity_tracker=entity_tracker, - state_publisher=state_publisher, - cursor=cursor, - entity_map=entity_map, - ctx=ctx, - logger=logger, - guard_rail=guard_rail, - force_full_sync=force_full_sync, - has_keyword_index=has_keyword_index, - ) - - # Set cursor on source so it can access cursor data - source.set_cursor(cursor) - - return sync_context - - @classmethod - async def _create_source_instance_with_data( - cls, - db: AsyncSession, - source_connection_data: dict, - ctx: ApiContext, - logger: ContextualLogger, - access_token: Optional[str] = None, - sync_job: Optional[Any] = None, - ) -> BaseSource: - """Create and configure the source instance using pre-fetched connection data.""" - # Get auth configuration (credentials + proxy setup if needed) - auth_config = await get_auth_configuration( - db=db, - source_connection_data=source_connection_data, - ctx=ctx, - logger=logger, - 
access_token=access_token, - ) - - # Process credentials for source consumption - source_credentials = await process_credentials_for_source( - raw_credentials=auth_config["credentials"], - source_connection_data=source_connection_data, - logger=logger, - ) - - # Create the source instance with processed credentials - source = await source_connection_data["source_class"].create( - source_credentials, config=source_connection_data["config_fields"] - ) - - # Configure source with logger - if hasattr(source, "set_logger"): - source.set_logger(logger) - - # Set HTTP client factory if proxy is needed - if auth_config.get("http_client_factory"): - source.set_http_client_factory(auth_config["http_client_factory"]) - - # Step 4.1: Pass sync identifiers to the source for scoped helpers - try: - organization_id = ctx.organization.id - source_connection_id = source_connection_data.get("source_connection_id") - if hasattr(source, "set_sync_identifiers") and source_connection_id: - source.set_sync_identifiers( - organization_id=str(organization_id), - source_connection_id=str(source_connection_id), - ) - except Exception: - # Non-fatal: older sources may ignore this - pass - - # Setup token manager for OAuth sources (if applicable) - # Skip for: - # 1. Direct token injection (when access_token parameter was explicitly passed) - # 2. 
Proxy mode (PipedreamProxyClient or other proxies manage tokens internally) - from airweave.platform.auth_providers.auth_result import AuthProviderMode - - auth_mode = auth_config.get("auth_mode") - auth_provider_instance = auth_config.get("auth_provider_instance") - - # Check if we should skip TokenManager - is_direct_token_injection = access_token is not None - is_proxy_mode = auth_mode == AuthProviderMode.PROXY - - if not is_direct_token_injection and not is_proxy_mode: - try: - await cls._setup_token_manager( - db=db, - source=source, - source_connection_data=source_connection_data, - source_credentials=auth_config["credentials"], - ctx=ctx, - logger=logger, - auth_provider_instance=auth_provider_instance, - ) - except Exception as e: - logger.error( - f"Failed to setup token manager for source " - f"'{source_connection_data['short_name']}': {e}" - ) - # Don't fail source creation if token manager setup fails - elif is_proxy_mode: - logger.info( - f"⏭️ Skipping token manager for {source_connection_data['short_name']} - " - f"proxy mode (PipedreamProxyClient manages tokens internally)" - ) - else: - logger.debug( - f"⏭️ Skipping token manager for {source_connection_data['short_name']} - " - f"direct token injection" - ) - - # Setup file downloader for file-based sources - cls._setup_file_downloader(source, sync_job, logger) - - # Wrap HTTP client with AirweaveHttpClient for rate limiting - # This wraps whatever client is currently set (httpx or Pipedream proxy) - from airweave.platform.utils.source_factory_utils import wrap_source_with_airweave_client - - wrap_source_with_airweave_client( - source=source, - source_short_name=source_connection_data["short_name"], - source_connection_id=source_connection_data["source_connection_id"], - ctx=ctx, - logger=logger, - ) - - return source - - @classmethod - async def _get_source_connection_data( - cls, db: AsyncSession, sync: schemas.Sync, ctx: ApiContext - ) -> dict: - """Get source connection and model data.""" - # 1. 
Get SourceConnection first (has most of our data) - source_connection_obj = await crud.source_connection.get_by_sync_id( - db, sync_id=sync.id, ctx=ctx - ) - if not source_connection_obj: - raise NotFoundException( - f"Source connection record not found for sync {sync.id}. " - f"This typically occurs when a source connection is deleted while a " - f"scheduled workflow is queued. The workflow should self-destruct and " - f"clean up orphaned schedules." - ) - - # 2. Get Connection only to access integration_credential_id - connection = await crud.connection.get(db, source_connection_obj.connection_id, ctx) - if not connection: - raise NotFoundException("Connection not found") - - # 3. Get Source model using short_name from SourceConnection - source_model = await crud.source.get_by_short_name(db, source_connection_obj.short_name) - if not source_model: - raise NotFoundException(f"Source not found: {source_connection_obj.short_name}") - - # Get all fields from the RIGHT places: - config_fields = source_connection_obj.config_fields or {} # From SourceConnection - - # Pre-fetch to avoid lazy loading - convert to pure Python types - auth_config_class = source_model.auth_config_class - # Convert SQLAlchemy values to clean Python types to avoid lazy loading - source_connection_id = UUID(str(source_connection_obj.id)) # From SourceConnection - short_name = str(source_connection_obj.short_name) # From SourceConnection - connection_id = UUID(str(connection.id)) - - # Check if this connection uses an auth provider - readable_auth_provider_id = getattr( - source_connection_obj, "readable_auth_provider_id", None - ) - - # For auth provider connections, integration_credential_id will be None - # For regular connections, integration_credential_id must be set - if not readable_auth_provider_id and not connection.integration_credential_id: - raise NotFoundException(f"Connection {connection_id} has no integration credential") - - integration_credential_id = ( - 
UUID(str(connection.integration_credential_id)) - if connection.integration_credential_id - else None - ) - - source_class = resource_locator.get_source(source_model) - - # Pre-fetch oauth_type to avoid lazy loading issues - oauth_type = str(source_model.oauth_type) if source_model.oauth_type else None - - return { - "source_connection_obj": source_connection_obj, # The main entity - "connection": connection, # Just for credential access - "source_model": source_model, - "source_class": source_class, - "config_fields": config_fields, # From SourceConnection - "short_name": short_name, # From SourceConnection - "source_connection_id": source_connection_id, # Pre-fetched to avoid lazy loading - "auth_config_class": auth_config_class, - "connection_id": connection_id, - "integration_credential_id": integration_credential_id, # From Connection - "oauth_type": oauth_type, # Pre-fetched to avoid lazy loading - "readable_auth_provider_id": getattr( - source_connection_obj, "readable_auth_provider_id", None - ), - "auth_provider_config": getattr(source_connection_obj, "auth_provider_config", None), - } - - @classmethod - def _setup_file_downloader( - cls, source: BaseSource, sync_job: Optional[Any], logger: ContextualLogger - ) -> None: - """Setup file downloader for file-based sources. - - All sources get a file downloader (even API-only sources) since BaseSource - provides set_file_downloader(). Sources that don't download files simply - won't use it. - - Args: - source: Source instance to configure - sync_job: Sync job for temp directory organization (required) - logger: Logger for diagnostics - - Raises: - ValueError: If sync_job is None (programming error) - """ - from airweave.platform.downloader import FileDownloadService - - # Require sync_job - we're always in sync context when this is called - if not sync_job or not hasattr(sync_job, "id"): - raise ValueError( - "sync_job is required for file downloader initialization. 
" - "This method should only be called from create_orchestrator() " - "where sync_job exists." - ) - - file_downloader = FileDownloadService(sync_job_id=str(sync_job.id)) - source.set_file_downloader(file_downloader) - logger.debug( - f"File downloader configured for {source.__class__.__name__} " - f"(sync_job_id: {sync_job.id})" - ) - - @classmethod - async def _setup_token_manager( - cls, - db: AsyncSession, - source: BaseSource, - source_connection_data: dict, - source_credentials: any, - ctx: ApiContext, - logger: ContextualLogger, - auth_provider_instance: Optional[BaseAuthProvider] = None, - ) -> None: - """Set up token manager for OAuth sources.""" - short_name = source_connection_data["short_name"] - # Use pre-fetched oauth_type to avoid SQLAlchemy lazy loading issues - oauth_type = source_connection_data.get("oauth_type") - - # Determine if we should create a token manager based on oauth_type - should_create_token_manager = False - - if oauth_type: - # Import OAuthType enum - from airweave.schemas.source_connection import OAuthType - - # Only create token manager for sources that support token refresh - if oauth_type in (OAuthType.WITH_REFRESH, OAuthType.WITH_ROTATING_REFRESH): - should_create_token_manager = True - logger.debug( - f"✅ OAuth source {short_name} with oauth_type={oauth_type} " - f"will use token manager for refresh" - ) - else: - logger.debug( - f"⏭️ Skipping token manager for {short_name} - " - f"oauth_type={oauth_type} does not support token refresh" - ) - - if should_create_token_manager: - # Create a minimal connection object with only the fields needed by TokenManager - # Use pre-fetched IDs to avoid SQLAlchemy lazy loading issues - minimal_source_connection = type( - "SourceConnection", - (), - { - "id": source_connection_data["connection_id"], - "integration_credential_id": source_connection_data[ - "integration_credential_id" - ], - "config_fields": source_connection_data.get("config_fields"), - }, - )() - - token_manager = 
TokenManager( - db=db, - source_short_name=short_name, - source_connection=minimal_source_connection, - ctx=ctx, - initial_credentials=source_credentials, - is_direct_injection=False, # TokenManager will determine this internally - logger_instance=logger, - auth_provider_instance=auth_provider_instance, - ) - source.set_token_manager(token_manager) - - logger.info( - f"Token manager initialized for OAuth source {short_name} " - f"(auth_provider: {'Yes' if auth_provider_instance else 'None'})" - ) - else: - logger.debug( - f"Skipping token manager for {short_name} - " - "not an OAuth source or no access_token in credentials" - ) - - @classmethod - def _create_destination_handlers( - cls, - sync_context: SyncContext, - ) -> list: - """Create destination handlers grouped by processing requirements. - - This method groups destinations by their processing requirements and creates - appropriate handlers: - - VectorDBHandler: For destinations needing chunking/embedding (Qdrant, Pinecone) - - Args: - sync_context: Sync context with destinations and logger - - Returns: - List of destination handlers (may be empty if no destinations) - """ - from airweave.platform.sync.handlers.base import ActionHandler - - handlers: list[ActionHandler] = [] - - # Group destinations by processing requirement - vector_db_destinations: list[BaseDestination] = [] - self_processing_destinations: list[BaseDestination] = [] - - for dest in sync_context.destinations: - requirement = dest.processing_requirement - if requirement == ProcessingRequirement.CHUNKS_AND_EMBEDDINGS: - vector_db_destinations.append(dest) - elif requirement == ProcessingRequirement.RAW_ENTITIES: - self_processing_destinations.append(dest) - else: - # Default to vector DB for unknown requirements (backward compat) - sync_context.logger.warning( - f"Unknown processing requirement {requirement} for {dest.__class__.__name__}, " - "defaulting to CHUNKS_AND_EMBEDDINGS" - ) - vector_db_destinations.append(dest) - - # Create handlers 
for each non-empty group - if vector_db_destinations: - vector_handler = VectorDBHandler(destinations=vector_db_destinations) - handlers.append(vector_handler) - sync_context.logger.info( - f"Created VectorDBHandler for {len(vector_db_destinations)} destination(s): " - f"{[d.__class__.__name__ for d in vector_db_destinations]}" - ) - - if not handlers: - sync_context.logger.warning( - "No destination handlers created - sync has no valid destinations" - ) - - return handlers - - @classmethod - async def _create_destination_instances( # noqa: C901 - cls, - db: AsyncSession, - sync: schemas.Sync, - collection: schemas.Collection, - ctx: ApiContext, - logger: ContextualLogger, - ) -> list[BaseDestination]: - """Create destination instances with unified credentials pattern (matches sources). - - Handles two special cases: - 1. NATIVE_QDRANT_UUID: Uses settings, no credentials needed - 2. Org-specific destinations (e.g., S3): Loads credentials from Connection - - Args: - ----- - db (AsyncSession): The database session - sync (schemas.Sync): The sync object - collection (schemas.Collection): The collection object - ctx (ApiContext): The API context - logger (ContextualLogger): The contextual logger with sync metadata - - Returns: - -------- - list[BaseDestination]: A list of successfully created destination instances - - Raises: - ------- - ValueError: If no destinations could be created - """ - destinations = [] - - # Create all destinations from destination_connection_ids - for destination_connection_id in sync.destination_connection_ids: - try: - # Special case: Native Qdrant (uses settings, no DB connection) - if destination_connection_id == NATIVE_QDRANT_UUID: - logger.info("Using native Qdrant destination (settings-based)") - destination_model = await crud.destination.get_by_short_name(db, "qdrant") - if not destination_model: - logger.warning("Qdrant destination model not found") - continue - - destination_schema = 
schemas.Destination.model_validate(destination_model) - destination_class = resource_locator.get_destination(destination_schema) - - # Fail-fast: vector_size must be set - if collection.vector_size is None: - raise ValueError(f"Collection {collection.id} has no vector_size set.") - - # Native Qdrant: no credentials (uses settings) - destination = await destination_class.create( - credentials=None, - config=None, - collection_id=collection.id, - organization_id=collection.organization_id, - vector_size=collection.vector_size, - logger=logger, - ) - - destinations.append(destination) - logger.info("Created native Qdrant destination") - continue - - # Regular case: Load connection from database - destination_connection = await crud.connection.get( - db, destination_connection_id, ctx - ) - if not destination_connection: - logger.warning( - f"Destination connection {destination_connection_id} not found, skipping" - ) - continue - - destination_model = await crud.destination.get_by_short_name( - db, destination_connection.short_name - ) - if not destination_model: - logger.warning( - f"Destination {destination_connection.short_name} not found, skipping" - ) - continue - - # Load credentials (contains both auth and config) - destination_credentials = None - if ( - destination_model.auth_config_class - and destination_connection.integration_credential_id - ): - credential = await crud.integration_credential.get( - db, destination_connection.integration_credential_id, ctx - ) - if credential: - decrypted_credential = credentials.decrypt(credential.encrypted_credentials) - auth_config_class = resource_locator.get_auth_config( - destination_model.auth_config_class - ) - destination_credentials = auth_config_class.model_validate( - decrypted_credential - ) - - # Create destination instance with credentials (config=None) - destination_schema = schemas.Destination.model_validate(destination_model) - destination_class = resource_locator.get_destination(destination_schema) - - 
destination = await destination_class.create( - credentials=destination_credentials, - config=None, # Everything is in credentials for now - collection_id=collection.id, - organization_id=collection.organization_id, - logger=logger, - collection_readable_id=collection.readable_id, - sync_id=sync.id, - ) - - destinations.append(destination) - logger.info( - f"Created destination: {destination_connection.short_name} " - f"(connection_id={destination_connection_id})" - ) - except Exception as e: - # Log error but continue to next destination - logger.error( - f"Failed to create destination {destination_connection_id}: {e}", exc_info=True - ) - continue - - if not destinations: - raise ValueError( - "No valid destinations could be created for sync. " - f"Tried {len(sync.destination_connection_ids)} connection(s)." - ) - - logger.info( - f"Successfully created {len(destinations)} destination(s) " - f"out of {len(sync.destination_connection_ids)} configured" - ) - - return destinations - - # NOTE: Transformers removed - chunking now happens in VectorDBHandler - # (for destinations requiring CHUNKS_AND_EMBEDDINGS processing) - - @classmethod - async def _get_entity_definition_map(cls, db: AsyncSession) -> dict[type[BaseEntity], UUID]: - """Get entity definition map. - - Map entity class to entity definition id. - - Example key-value pair: - : entity_definition_id - """ - # Ensure the reserved polymorphic entity definition exists (idempotent). 
- await init_db_with_entity_definitions(db) - - entity_definitions = await crud.entity_definition.get_all(db) - - entity_definition_map = {} - for entity_definition in entity_definitions: - if entity_definition.id == RESERVED_TABLE_ENTITY_ID: - continue - full_module_name = f"airweave.platform.entities.{entity_definition.module_name}" - module = importlib.import_module(full_module_name) - entity_class = getattr(module, entity_definition.class_name) - entity_definition_map[entity_class] = entity_definition.id - - return entity_definition_map diff --git a/backend/airweave/platform/sync/factory/__init__.py b/backend/airweave/platform/sync/factory/__init__.py new file mode 100644 index 000000000..eaab08d14 --- /dev/null +++ b/backend/airweave/platform/sync/factory/__init__.py @@ -0,0 +1,17 @@ +"""Sync factory module - creates orchestrators for sync operations. + +Public API: + from airweave.platform.sync.factory import SyncFactory + + orchestrator = await SyncFactory.create_orchestrator(...) + +Internal builders (for sibling modules like multiplex/): + Import directly from private modules: + from airweave.platform.sync.factory._destination import DestinationBuilder + from airweave.platform.sync.factory._context import ReplayContextBuilder + from airweave.platform.sync.factory._pipeline import PipelineBuilder +""" + +from airweave.platform.sync.factory._factory import SyncFactory + +__all__ = ["SyncFactory"] diff --git a/backend/airweave/platform/sync/factory/_context.py b/backend/airweave/platform/sync/factory/_context.py new file mode 100644 index 000000000..c4df49493 --- /dev/null +++ b/backend/airweave/platform/sync/factory/_context.py @@ -0,0 +1,243 @@ +"""Context builder - creates SyncContext with all dependencies. + +This is an internal implementation detail of the factory module. 
+""" + +import asyncio + +from sqlalchemy.ext.asyncio import AsyncSession + +from airweave import crud, schemas +from airweave.api.context import ApiContext +from airweave.core.guard_rail_service import GuardRailService +from airweave.core.logging import ContextualLogger +from airweave.core.sync_cursor_service import sync_cursor_service +from airweave.platform.destinations._base import BaseDestination +from airweave.platform.sources._base import BaseSource +from airweave.platform.sync.context import SyncContext +from airweave.platform.sync.cursor import SyncCursor +from airweave.platform.sync.pipeline.entity_tracker import EntityTracker +from airweave.platform.sync.state_publisher import SyncStatePublisher + + +class ContextBuilder: + """Builder for creating SyncContext with all its dependencies. + + Handles: + - EntityTracker creation with initial counts + - SyncStatePublisher setup + - Cursor loading (incremental vs full sync) + - GuardRailService creation + - Keyword index capability detection + """ + + def __init__( + self, + db: AsyncSession, + ctx: ApiContext, + logger: ContextualLogger, + ): + """Initialize the context builder.""" + self.db = db + self.ctx = ctx + self.logger = logger + + async def build( + self, + source: BaseSource, + source_connection_data: dict, + destinations: list[BaseDestination], + sync: schemas.Sync, + sync_job: schemas.SyncJob, + collection: schemas.Collection, + entity_map: dict, + force_full_sync: bool = False, + ) -> SyncContext: + """Build a complete SyncContext.""" + # 1. Load initial entity counts + initial_counts = await crud.entity_count.get_counts_per_sync_and_type(self.db, sync.id) + self.logger.info(f"🔢 Loaded initial entity counts: {len(initial_counts)} entity types") + + for count in initial_counts: + self.logger.debug(f" - {count.entity_definition_name}: {count.count} entities") + + # 2. 
Create EntityTracker + entity_tracker = EntityTracker( + job_id=sync_job.id, + sync_id=sync.id, + logger=self.logger, + initial_counts=initial_counts, + ) + + # 3. Create SyncStatePublisher + state_publisher = SyncStatePublisher( + job_id=sync_job.id, + sync_id=sync.id, + entity_tracker=entity_tracker, + logger=self.logger, + ) + + self.logger.info(f"✅ Created EntityTracker and SyncStatePublisher for job {sync_job.id}") + + # 4. Create GuardRailService + guard_rail = GuardRailService( + organization_id=self.ctx.organization.id, + logger=self.logger.with_context(component="guardrail"), + ) + + # 5. Create cursor + cursor = await self._create_cursor( + sync=sync, + source_connection_data=source_connection_data, + force_full_sync=force_full_sync, + ) + + # 6. Detect keyword index capability + has_keyword_index = await self._detect_keyword_index(destinations) + + # 7. Build SyncContext + sync_context = SyncContext( + source=source, + destinations=destinations, + sync=sync, + sync_job=sync_job, + collection=collection, + connection=None, + entity_tracker=entity_tracker, + state_publisher=state_publisher, + cursor=cursor, + entity_map=entity_map, + ctx=self.ctx, + logger=self.logger, + guard_rail=guard_rail, + force_full_sync=force_full_sync, + has_keyword_index=has_keyword_index, + ) + + # 8. 
Set cursor on source + source.set_cursor(cursor) + + self.logger.info("Sync context created") + return sync_context + + async def _create_cursor( + self, + sync: schemas.Sync, + source_connection_data: dict, + force_full_sync: bool, + ) -> SyncCursor: + """Create cursor with optional data loading.""" + cursor_schema = None + source_class = source_connection_data["source_class"] + if hasattr(source_class, "_cursor_class") and source_class._cursor_class: + cursor_schema = source_class._cursor_class + self.logger.debug(f"Source has typed cursor: {cursor_schema.__name__}") + + cursor_data = None + if force_full_sync: + self.logger.info( + "🔄 FORCE FULL SYNC: Skipping cursor data to ensure all entities are fetched " + "for accurate orphaned entity cleanup. Will still track cursor for next sync." + ) + else: + cursor_data = await sync_cursor_service.get_cursor_data( + db=self.db, sync_id=sync.id, ctx=self.ctx + ) + if cursor_data: + self.logger.info( + f"📊 Incremental sync: Using cursor data with {len(cursor_data)} keys" + ) + + return SyncCursor( + sync_id=sync.id, + cursor_schema=cursor_schema, + cursor_data=cursor_data, + ) + + async def _detect_keyword_index( + self, + destinations: list[BaseDestination], + ) -> bool: + """Detect if any destination has keyword index capability.""" + if not destinations: + return False + + try: + results = await asyncio.gather( + *[dest.has_keyword_index() for dest in destinations], + return_exceptions=True, + ) + return any(r is True for r in results if not isinstance(r, Exception)) + except Exception as e: + self.logger.warning(f"Failed to detect keyword index capability: {e}") + return False + + +class ReplayContextBuilder: + """Simplified context builder for replay operations. + + Creates a lightweight context without source connection data + or cursor tracking (since replay doesn't need incremental sync). 
+ """ + + def __init__( + self, + db: AsyncSession, + ctx: ApiContext, + logger: ContextualLogger, + ): + """Initialize the replay context builder.""" + self.db = db + self.ctx = ctx + self.logger = logger + + async def build( + self, + source: BaseSource, + destinations: list[BaseDestination], + sync: schemas.Sync, + sync_job: schemas.SyncJob, + collection: schemas.Collection, + entity_map: dict, + ) -> SyncContext: + """Build a SyncContext for replay operations.""" + initial_counts = await crud.entity_count.get_counts_per_sync_and_type(self.db, sync.id) + + entity_tracker = EntityTracker( + job_id=sync_job.id, + sync_id=sync.id, + logger=self.logger, + initial_counts=initial_counts, + ) + + state_publisher = SyncStatePublisher( + job_id=sync_job.id, + sync_id=sync.id, + entity_tracker=entity_tracker, + logger=self.logger, + ) + + guard_rail = GuardRailService( + organization_id=self.ctx.organization.id, + logger=self.logger.with_context(component="guardrail"), + ) + + cursor = SyncCursor(sync_id=sync.id) + + return SyncContext( + source=source, + destinations=destinations, + sync=sync, + sync_job=sync_job, + collection=collection, + connection=None, + entity_tracker=entity_tracker, + state_publisher=state_publisher, + cursor=cursor, + entity_map=entity_map, + ctx=self.ctx, + logger=self.logger, + guard_rail=guard_rail, + force_full_sync=True, + has_keyword_index=False, + ) diff --git a/backend/airweave/platform/sync/factory/_destination.py b/backend/airweave/platform/sync/factory/_destination.py new file mode 100644 index 000000000..d14ec3f9d --- /dev/null +++ b/backend/airweave/platform/sync/factory/_destination.py @@ -0,0 +1,203 @@ +"""Destination builder - creates and configures destination instances. + +This is an internal implementation detail of the factory module. 
+""" + +from typing import Optional +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from airweave import crud, schemas +from airweave.api.context import ApiContext +from airweave.core import credentials +from airweave.core.constants.reserved_ids import NATIVE_QDRANT_UUID +from airweave.core.logging import ContextualLogger +from airweave.platform.destinations._base import BaseDestination +from airweave.platform.locator import resource_locator + + +class DestinationBuilder: + """Builder for creating destination instances. + + Handles: + - Loading destination connections + - Credential decryption + - Native Qdrant (settings-based) vs custom destinations + - Filtering by role (active/shadow/deprecated) + """ + + def __init__( + self, + db: AsyncSession, + ctx: ApiContext, + logger: ContextualLogger, + ): + """Initialize the destination builder.""" + self.db = db + self.ctx = ctx + self.logger = logger + + async def build( + self, + sync: schemas.Sync, + collection: schemas.Collection, + ) -> list[BaseDestination]: + """Build destination instances for a sync. + + Respects destination roles: + - ACTIVE: receives writes + serves queries + - SHADOW: receives writes only (migration testing) + - DEPRECATED: skipped (no writes) + """ + destination_ids = await self._get_active_ids(sync) + + destinations = await self.build_for_ids( + destination_ids=destination_ids, + collection=collection, + sync_id=sync.id, + ) + + if not destinations: + raise ValueError( + f"No valid destinations could be created for sync. " + f"Tried {len(sync.destination_connection_ids)} connection(s)." 
+ ) + + self.logger.info( + f"Successfully created {len(destinations)} destination(s) " + f"out of {len(sync.destination_connection_ids)} configured" + ) + + return destinations + + async def build_for_ids( + self, + destination_ids: list[UUID], + collection: schemas.Collection, + sync_id: Optional[UUID] = None, + ) -> list[BaseDestination]: + """Build destination instances for specific connection IDs.""" + destinations = [] + + for dest_id in destination_ids: + try: + destination = await self._create_single( + dest_id=dest_id, + collection=collection, + sync_id=sync_id, + ) + if destination: + destinations.append(destination) + except Exception as e: + self.logger.error( + f"Failed to create destination {dest_id}: {e}", + exc_info=True, + ) + + return destinations + + async def _get_active_ids(self, sync: schemas.Sync) -> list[UUID]: + """Get destination IDs that should receive writes.""" + slots = await crud.sync_connection.get_active_and_shadow(self.db, sync_id=sync.id) + + if slots: + destination_ids = [slot.connection_id for slot in slots] + deprecated_count = len(sync.destination_connection_ids) - len(destination_ids) + if deprecated_count > 0: + self.logger.info( + f"Filtered {deprecated_count} deprecated destination(s), " + f"using {len(destination_ids)} active/shadow destination(s)" + ) + return destination_ids + + return list(sync.destination_connection_ids) + + async def _create_single( + self, + dest_id: UUID, + collection: schemas.Collection, + sync_id: Optional[UUID], + ) -> Optional[BaseDestination]: + """Create a single destination instance.""" + if dest_id == NATIVE_QDRANT_UUID: + return await self._create_native_qdrant(collection) + + return await self._create_custom_destination( + dest_id=dest_id, + collection=collection, + sync_id=sync_id, + ) + + async def _create_native_qdrant( + self, + collection: schemas.Collection, + ) -> Optional[BaseDestination]: + """Create native Qdrant destination (settings-based, no credentials).""" + 
self.logger.info("Using native Qdrant destination (settings-based)") + + destination_model = await crud.destination.get_by_short_name(self.db, "qdrant") + if not destination_model: + self.logger.warning("Qdrant destination model not found") + return None + + if collection.vector_size is None: + raise ValueError(f"Collection {collection.id} has no vector_size set.") + + dest_schema = schemas.Destination.model_validate(destination_model) + dest_class = resource_locator.get_destination(dest_schema) + + destination = await dest_class.create( + credentials=None, + config=None, + collection_id=collection.id, + organization_id=collection.organization_id, + vector_size=collection.vector_size, + logger=self.logger, + ) + + self.logger.info("Created native Qdrant destination") + return destination + + async def _create_custom_destination( + self, + dest_id: UUID, + collection: schemas.Collection, + sync_id: Optional[UUID], + ) -> Optional[BaseDestination]: + """Create a custom destination with credentials.""" + connection = await crud.connection.get(self.db, dest_id, self.ctx) + if not connection: + self.logger.warning(f"Destination connection {dest_id} not found, skipping") + return None + + destination_model = await crud.destination.get_by_short_name(self.db, connection.short_name) + if not destination_model: + self.logger.warning(f"Destination {connection.short_name} not found, skipping") + return None + + dest_credentials = None + if destination_model.auth_config_class and connection.integration_credential_id: + credential = await crud.integration_credential.get( + self.db, connection.integration_credential_id, self.ctx + ) + if credential: + decrypted = credentials.decrypt(credential.encrypted_credentials) + auth_config = resource_locator.get_auth_config(destination_model.auth_config_class) + dest_credentials = auth_config.model_validate(decrypted) + + dest_schema = schemas.Destination.model_validate(destination_model) + dest_class = 
resource_locator.get_destination(dest_schema) + + destination = await dest_class.create( + credentials=dest_credentials, + config=None, + collection_id=collection.id, + organization_id=collection.organization_id, + logger=self.logger, + collection_readable_id=collection.readable_id, + sync_id=sync_id, + ) + + self.logger.info(f"Created destination: {connection.short_name} (connection_id={dest_id})") + return destination diff --git a/backend/airweave/platform/sync/factory/_factory.py b/backend/airweave/platform/sync/factory/_factory.py new file mode 100644 index 000000000..5d7c4a53e --- /dev/null +++ b/backend/airweave/platform/sync/factory/_factory.py @@ -0,0 +1,180 @@ +"""SyncFactory - creates orchestrators for sync operations. + +This is the main factory class. See __init__.py for public exports. +""" + +import importlib +import time +from typing import Optional +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from airweave import crud, schemas +from airweave.api.context import ApiContext +from airweave.core.config import settings +from airweave.core.constants.reserved_ids import RESERVED_TABLE_ENTITY_ID +from airweave.core.logging import LoggerConfigurator, logger +from airweave.db.init_db_native import init_db_with_entity_definitions +from airweave.platform.entities._base import BaseEntity +from airweave.platform.sync.factory._context import ContextBuilder +from airweave.platform.sync.factory._destination import DestinationBuilder +from airweave.platform.sync.factory._pipeline import PipelineBuilder +from airweave.platform.sync.factory._source import SourceBuilder +from airweave.platform.sync.orchestrator import SyncOrchestrator +from airweave.platform.sync.stream import AsyncSourceStream +from airweave.platform.sync.worker_pool import AsyncWorkerPool + + +class SyncFactory: + """Factory for creating sync orchestrators. 
+ + Example: + orchestrator = await SyncFactory.create_orchestrator( + db=db, + sync=sync, + sync_job=sync_job, + collection=collection, + connection=connection, + ctx=ctx, + ) + await orchestrator.run() + + For replay operations, see sync/multiplex/ module. + """ + + @classmethod + async def create_orchestrator( + cls, + db: AsyncSession, + sync: schemas.Sync, + sync_job: schemas.SyncJob, + collection: schemas.Collection, + connection: schemas.Connection, # Unused - kept for backwards compatibility + ctx: ApiContext, + access_token: Optional[str] = None, + max_workers: Optional[int] = None, + force_full_sync: bool = False, + ) -> SyncOrchestrator: + """Create a sync orchestrator with all required components. + + Args: + db: Database session + sync: Sync configuration + sync_job: Sync job for tracking + collection: Target collection + connection: Unused (kept for API compatibility) + ctx: API context + access_token: Optional token override + max_workers: Max concurrent workers (default: from settings) + force_full_sync: Whether to force full sync (skips cursor) + + Returns: + Configured SyncOrchestrator ready to run + """ + if max_workers is None: + max_workers = settings.SYNC_MAX_WORKERS + logger.debug(f"Using configured max_workers: {max_workers}") + + init_start = time.time() + + # 1. Create contextual logger + sync_logger = LoggerConfigurator.configure_logger( + "airweave.platform.sync", + dimensions={ + "sync_id": str(sync.id), + "sync_job_id": str(sync_job.id), + "organization_id": str(ctx.organization.id), + "collection_readable_id": str(collection.readable_id), + "organization_name": ctx.organization.name, + "scheduled": str(sync_job.scheduled), + }, + ) + + # 2. 
Build source + source_builder = SourceBuilder(db, ctx, sync_logger) + source, source_connection_data = await source_builder.build( + sync=sync, + access_token=access_token, + sync_job=sync_job, + ) + + # Update logger with source connection ID + sync_logger = LoggerConfigurator.configure_logger( + "airweave.platform.sync", + dimensions={ + "sync_id": str(sync.id), + "sync_job_id": str(sync_job.id), + "organization_id": str(ctx.organization.id), + "source_connection_id": str(source_connection_data["source_connection_id"]), + "collection_readable_id": str(collection.readable_id), + "organization_name": ctx.organization.name, + "scheduled": str(sync_job.scheduled), + }, + ) + + # 3. Build destinations + dest_builder = DestinationBuilder(db, ctx, sync_logger) + destinations = await dest_builder.build(sync=sync, collection=collection) + + # 4. Get entity map + entity_map = await cls._get_entity_definition_map(db) + + # 5. Build sync context + context_builder = ContextBuilder(db, ctx, sync_logger) + sync_context = await context_builder.build( + source=source, + source_connection_data=source_connection_data, + destinations=destinations, + sync=sync, + sync_job=sync_job, + collection=collection, + entity_map=entity_map, + force_full_sync=force_full_sync, + ) + + # 6. Build pipeline + entity_pipeline = PipelineBuilder.build( + sync_context=sync_context, + include_raw_data_handler=True, + ) + + # 7. Create worker pool and stream + worker_pool = AsyncWorkerPool(max_workers=max_workers, logger=sync_logger) + stream = AsyncSourceStream( + source_generator=source.generate_entities(), + queue_size=10000, + logger=sync_logger, + ) + + # 8. 
Create orchestrator + orchestrator = SyncOrchestrator( + entity_pipeline=entity_pipeline, + worker_pool=worker_pool, + stream=stream, + sync_context=sync_context, + ) + + logger.info(f"Total orchestrator initialization took {time.time() - init_start:.2f}s") + return orchestrator + + @classmethod + async def _get_entity_definition_map( + cls, + db: AsyncSession, + ) -> dict[type[BaseEntity], UUID]: + """Get entity class to definition ID map.""" + await init_db_with_entity_definitions(db) + + entity_definitions = await crud.entity_definition.get_all(db) + + entity_map = {} + for entity_def in entity_definitions: + if entity_def.id == RESERVED_TABLE_ENTITY_ID: + continue + full_module = f"airweave.platform.entities.{entity_def.module_name}" + module = importlib.import_module(full_module) + entity_class = getattr(module, entity_def.class_name) + entity_map[entity_class] = entity_def.id + + return entity_map diff --git a/backend/airweave/platform/sync/factory/_pipeline.py b/backend/airweave/platform/sync/factory/_pipeline.py new file mode 100644 index 000000000..3773e1d62 --- /dev/null +++ b/backend/airweave/platform/sync/factory/_pipeline.py @@ -0,0 +1,119 @@ +"""Pipeline builder - creates EntityPipeline with handlers. + +This is an internal implementation detail of the factory module. +""" + +from airweave.platform.destinations._base import BaseDestination, ProcessingRequirement +from airweave.platform.sync.actions import ActionDispatcher, ActionResolver +from airweave.platform.sync.context import SyncContext +from airweave.platform.sync.entity_pipeline import EntityPipeline +from airweave.platform.sync.handlers import ( + PostgresMetadataHandler, + RawDataHandler, + VectorDBHandler, +) +from airweave.platform.sync.handlers.base import ActionHandler +from airweave.platform.sync.pipeline.entity_tracker import EntityTracker + + +class PipelineBuilder: + """Builder for creating EntityPipeline with appropriate handlers. 
+ + Handlers are created based on destination processing requirements: + - VectorDBHandler: For destinations needing chunking/embedding + - RawDataHandler: For ARF storage (skipped in replay mode) + - PostgresMetadataHandler: For entity metadata (always runs last) + """ + + @staticmethod + def build( + sync_context: SyncContext, + include_raw_data_handler: bool = True, + ) -> EntityPipeline: + """Build an EntityPipeline for a sync context.""" + action_resolver = ActionResolver(entity_map=sync_context.entity_map) + + handlers = PipelineBuilder._create_handlers( + destinations=sync_context.destinations, + logger=sync_context.logger, + include_raw_data_handler=include_raw_data_handler, + ) + + action_dispatcher = ActionDispatcher(handlers=handlers) + + return EntityPipeline( + entity_tracker=sync_context.entity_tracker, + action_resolver=action_resolver, + action_dispatcher=action_dispatcher, + ) + + @staticmethod + def build_for_replay( + entity_tracker: EntityTracker, + entity_map: dict, + destinations: list[BaseDestination], + ) -> EntityPipeline: + """Build a pipeline for replay operations. + + Replay pipelines skip RawDataHandler (we're reading from ARF, not writing). 
+ """ + action_resolver = ActionResolver(entity_map=entity_map) + + handlers = PipelineBuilder._create_handlers( + destinations=destinations, + logger=None, + include_raw_data_handler=False, + ) + + action_dispatcher = ActionDispatcher(handlers=handlers) + + return EntityPipeline( + entity_tracker=entity_tracker, + action_resolver=action_resolver, + action_dispatcher=action_dispatcher, + ) + + @staticmethod + def _create_handlers( + destinations: list[BaseDestination], + logger, + include_raw_data_handler: bool = True, + ) -> list[ActionHandler]: + """Create handlers based on destination requirements.""" + handlers: list[ActionHandler] = [] + + vector_db_destinations: list[BaseDestination] = [] + self_processing_destinations: list[BaseDestination] = [] + + for dest in destinations: + requirement = dest.processing_requirement + if requirement == ProcessingRequirement.CHUNKS_AND_EMBEDDINGS: + vector_db_destinations.append(dest) + elif requirement == ProcessingRequirement.RAW_ENTITIES: + self_processing_destinations.append(dest) + else: + if logger: + logger.warning( + f"Unknown processing requirement {requirement} for " + f"{dest.__class__.__name__}, defaulting to CHUNKS_AND_EMBEDDINGS" + ) + vector_db_destinations.append(dest) + + if vector_db_destinations: + vector_handler = VectorDBHandler(destinations=vector_db_destinations) + handlers.append(vector_handler) + if logger: + logger.info( + f"Created VectorDBHandler for {len(vector_db_destinations)} destination(s): " + f"{[d.__class__.__name__ for d in vector_db_destinations]}" + ) + + if include_raw_data_handler: + handlers.append(RawDataHandler()) + + handlers.append(PostgresMetadataHandler()) + + if not handlers and logger: + logger.warning("No destination handlers created - sync has no valid destinations") + + return handlers diff --git a/backend/airweave/platform/sync/factory/_source.py b/backend/airweave/platform/sync/factory/_source.py new file mode 100644 index 000000000..52fa216a9 --- /dev/null +++ 
class SourceBuilder:
    """Builder for creating and configuring source instances.

    Handles:
    - Loading source connection data
    - Authentication (credentials, OAuth, auth providers)
    - Token manager setup for OAuth refresh
    - File downloader setup for file-based sources
    - HTTP client wrapping for rate limiting
    """

    def __init__(
        self,
        db: AsyncSession,
        ctx: ApiContext,
        logger: ContextualLogger,
    ):
        """Initialize the source builder."""
        self.db = db
        self.ctx = ctx
        self.logger = logger

    async def build(
        self,
        sync: schemas.Sync,
        access_token: Optional[str] = None,
        sync_job: Optional[schemas.SyncJob] = None,
    ) -> tuple[BaseSource, dict]:
        """Build a fully configured source instance.

        Args:
            sync: Sync whose source connection should be instantiated.
            access_token: Optional direct token injection (skips token manager).
            sync_job: Required for file downloader setup.

        Returns:
            Tuple of (source instance, source connection data dict)
        """
        # 1. Load source connection data
        connection_data = await self._get_connection_data(sync)

        # 2. Get auth configuration
        auth_config = await get_auth_configuration(
            db=self.db,
            source_connection_data=connection_data,
            ctx=self.ctx,
            logger=self.logger,
            access_token=access_token,
        )

        # 3. Process credentials for source
        source_credentials = await process_credentials_for_source(
            raw_credentials=auth_config["credentials"],
            source_connection_data=connection_data,
            logger=self.logger,
        )

        # 4. Create source instance
        source = await connection_data["source_class"].create(
            source_credentials, config=connection_data["config_fields"]
        )

        # 5. Configure source
        await self._configure_source(
            source=source,
            connection_data=connection_data,
            auth_config=auth_config,
            access_token=access_token,
            sync_job=sync_job,
        )

        return source, connection_data

    async def _get_connection_data(self, sync: schemas.Sync) -> dict:
        """Load source connection and related data.

        Raises:
            NotFoundException: If the source connection, connection, source
                model, or credential linkage is missing.
        """
        # 1. Get SourceConnection first
        source_connection = await crud.source_connection.get_by_sync_id(
            self.db, sync_id=sync.id, ctx=self.ctx
        )
        if not source_connection:
            raise NotFoundException(
                f"Source connection record not found for sync {sync.id}. "
                f"This typically occurs when a source connection is deleted while a "
                f"scheduled workflow is queued."
            )

        # 2. Get Connection for integration_credential_id
        connection = await crud.connection.get(self.db, source_connection.connection_id, self.ctx)
        if not connection:
            raise NotFoundException("Connection not found")

        # 3. Get Source model
        source_model = await crud.source.get_by_short_name(self.db, source_connection.short_name)
        if not source_model:
            raise NotFoundException(f"Source not found: {source_connection.short_name}")

        # 4. Pre-fetch to avoid lazy loading
        config_fields = source_connection.config_fields or {}
        auth_config_class = source_model.auth_config_class
        source_connection_id = UUID(str(source_connection.id))
        short_name = str(source_connection.short_name)
        connection_id = UUID(str(connection.id))

        readable_auth_provider_id = getattr(source_connection, "readable_auth_provider_id", None)

        # A source must authenticate either via an auth provider or a stored
        # credential; neither present is a hard error.
        if not readable_auth_provider_id and not connection.integration_credential_id:
            raise NotFoundException(f"Connection {connection_id} has no integration credential")

        integration_credential_id = (
            UUID(str(connection.integration_credential_id))
            if connection.integration_credential_id
            else None
        )

        source_class = resource_locator.get_source(source_model)
        oauth_type = str(source_model.oauth_type) if source_model.oauth_type else None

        return {
            "source_connection_obj": source_connection,
            "connection": connection,
            "source_model": source_model,
            "source_class": source_class,
            "config_fields": config_fields,
            "short_name": short_name,
            "source_connection_id": source_connection_id,
            "auth_config_class": auth_config_class,
            "connection_id": connection_id,
            "integration_credential_id": integration_credential_id,
            "oauth_type": oauth_type,
            "readable_auth_provider_id": readable_auth_provider_id,
            "auth_provider_config": getattr(source_connection, "auth_provider_config", None),
        }

    async def _configure_source(
        self,
        source: BaseSource,
        connection_data: dict,
        auth_config: dict,
        access_token: Optional[str],
        sync_job: Optional[schemas.SyncJob],
    ) -> None:
        """Configure source with logger, clients, token manager, and file downloader."""
        short_name = connection_data["short_name"]

        if hasattr(source, "set_logger"):
            source.set_logger(self.logger)

        if auth_config.get("http_client_factory"):
            source.set_http_client_factory(auth_config["http_client_factory"])

        # Set sync identifiers (best-effort: these are used for telemetry, so a
        # failure here must never abort the sync - but it should not vanish
        # silently either, hence the debug log instead of a bare `pass`).
        try:
            source_connection_id = connection_data.get("source_connection_id")
            if hasattr(source, "set_sync_identifiers") and source_connection_id:
                source.set_sync_identifiers(
                    organization_id=str(self.ctx.organization.id),
                    source_connection_id=str(source_connection_id),
                )
        except Exception as exc:
            self.logger.debug(f"Could not set sync identifiers for {short_name}: {exc}")

        # Setup token manager
        await self._setup_token_manager(
            source=source,
            connection_data=connection_data,
            credentials=auth_config["credentials"],
            auth_config=auth_config,
            access_token=access_token,
        )

        # Setup file downloader
        self._setup_file_downloader(source, sync_job)

        # Wrap HTTP client
        from airweave.platform.utils.source_factory_utils import wrap_source_with_airweave_client

        wrap_source_with_airweave_client(
            source=source,
            source_short_name=short_name,
            source_connection_id=connection_data["source_connection_id"],
            ctx=self.ctx,
            logger=self.logger,
        )

    async def _setup_token_manager(
        self,
        source: BaseSource,
        connection_data: dict,
        credentials: Any,
        auth_config: dict,
        access_token: Optional[str],
    ) -> None:
        """Set up token manager for OAuth sources.

        Skipped for direct token injection, proxy-mode auth providers, and
        OAuth types that do not support refresh.
        """
        from airweave.platform.auth_providers.auth_result import AuthProviderMode
        from airweave.schemas.source_connection import OAuthType

        short_name = connection_data["short_name"]
        oauth_type = connection_data.get("oauth_type")
        auth_mode = auth_config.get("auth_mode")
        auth_provider_instance = auth_config.get("auth_provider_instance")

        if access_token is not None:
            self.logger.debug(f"⏭️ Skipping token manager for {short_name} - direct token injection")
            return

        if auth_mode == AuthProviderMode.PROXY:
            self.logger.info(
                f"⏭️ Skipping token manager for {short_name} - "
                f"proxy mode (PipedreamProxyClient manages tokens internally)"
            )
            return

        if oauth_type not in (OAuthType.WITH_REFRESH, OAuthType.WITH_ROTATING_REFRESH):
            self.logger.debug(
                f"⏭️ Skipping token manager for {short_name} - "
                f"oauth_type={oauth_type} does not support refresh"
            )
            return

        try:
            # Minimal duck-typed stand-in for a SourceConnection: TokenManager
            # only reads these three attributes.
            minimal_connection = type(
                "SourceConnection",
                (),
                {
                    "id": connection_data["connection_id"],
                    "integration_credential_id": connection_data["integration_credential_id"],
                    "config_fields": connection_data.get("config_fields"),
                },
            )()

            token_manager = TokenManager(
                db=self.db,
                source_short_name=short_name,
                source_connection=minimal_connection,
                ctx=self.ctx,
                initial_credentials=credentials,
                is_direct_injection=False,
                logger_instance=self.logger,
                auth_provider_instance=auth_provider_instance,
            )
            source.set_token_manager(token_manager)

            self.logger.info(
                f"Token manager initialized for OAuth source {short_name} "
                f"(auth_provider: {'Yes' if auth_provider_instance else 'None'})"
            )
        except Exception as e:
            # Token manager failure is non-fatal here: the sync proceeds with
            # the initial credentials and fails later only if a refresh is needed.
            self.logger.error(f"Failed to setup token manager for {short_name}: {e}")

    def _setup_file_downloader(
        self,
        source: BaseSource,
        sync_job: Optional[schemas.SyncJob],
    ) -> None:
        """Setup file downloader for file-based sources.

        Raises:
            ValueError: If sync_job is missing (downloader is keyed by job ID).
        """
        from airweave.platform.downloader import FileDownloadService

        if not sync_job or not hasattr(sync_job, "id"):
            raise ValueError(
                "sync_job is required for file downloader initialization. "
                "This method should only be called from create_orchestrator()."
            )

        file_downloader = FileDownloadService(sync_job_id=str(sync_job.id))
        source.set_file_downloader(file_downloader)
        self.logger.debug(
            f"File downloader configured for {source.__class__.__name__} "
            f"(sync_job_id: {sync_job.id})"
        )
+ +This module provides: +- SyncMultiplexer: Manages destination slots for migrations and blue-green deployments +- ARFReplaySource: Pseudo-source that reads from ARF storage +- replay_to_destination: Replays entities to a specific destination + +Typical migration workflow: +1. multiplexer.resync_from_source() - Ensure ARF is up-to-date +2. multiplexer.fork() - Create shadow destination, optionally replay from ARF +3. Validate shadow destination (search quality, etc.) +4. multiplexer.switch() - Promote shadow to active +5. (Optional) cleanup deprecated destinations +""" + +from airweave.platform.sync.multiplex.multiplexer import SyncMultiplexer, get_multiplexer +from airweave.platform.sync.multiplex.replay import ( + ARFReplaySource, + create_replay_orchestrator, + replay_to_destination, +) + +__all__ = [ + "SyncMultiplexer", + "get_multiplexer", + "ARFReplaySource", + "replay_to_destination", + "create_replay_orchestrator", +] diff --git a/backend/airweave/platform/sync/multiplex/multiplexer.py b/backend/airweave/platform/sync/multiplex/multiplexer.py new file mode 100644 index 000000000..559b90902 --- /dev/null +++ b/backend/airweave/platform/sync/multiplex/multiplexer.py @@ -0,0 +1,403 @@ +"""Sync multiplexer for managing multiple destinations per sync. 
class SyncMultiplexer:
    """Manages destination slots for a sync.

    Use cases:
    - Migration: Qdrant → Vespa, Vespa v0 → Vespa v1
    - A/B testing: Compare search quality across configs
    - Rollback: Keep previous destination available

    Typical workflow:
    1. resync_from_source() - Ensure ARF is up-to-date
    2. fork() - Create shadow destination, optionally replay from ARF
    3. Validate shadow destination (search quality, etc.)
    4. switch() - Promote shadow to active
    5. (Optional) cleanup deprecated destinations
    """

    def __init__(self, db: AsyncSession, ctx: ApiContext, logger: ContextualLogger):
        """Initialize the multiplexer.

        Args:
            db: Database session
            ctx: API context (used for access control)
            logger: Contextual logger
        """
        self.db = db
        self.ctx = ctx
        self.logger = logger

    # =========================================================================
    # Fork: Create shadow destination
    # =========================================================================

    async def fork(
        self,
        sync_id: UUID,
        destination_connection_id: UUID,
        replay_from_arf: bool = False,
    ) -> tuple[SyncConnection, Optional[schemas.SyncJob]]:
        """Create a shadow destination and optionally populate from ARF.

        Args:
            sync_id: Sync to fork destination for
            destination_connection_id: New destination connection to add
            replay_from_arf: If True, kicks off replay job from ARF

        Returns:
            Tuple of (SyncConnection, Optional[SyncJob])
            - SyncJob is returned if replay_from_arf=True

        Raises:
            HTTPException: If sync not found, destination invalid, or already exists
        """
        # 1. Validate sync exists and user has access
        sync = await crud.sync.get(self.db, id=sync_id, ctx=self.ctx, with_connections=False)
        if not sync:
            raise HTTPException(status_code=404, detail=f"Sync {sync_id} not found")

        # 2. Check destination connection exists and is a valid destination
        dest_conn = await crud.connection.get(self.db, id=destination_connection_id, ctx=self.ctx)
        if not dest_conn:
            raise HTTPException(
                status_code=404, detail=f"Connection {destination_connection_id} not found"
            )
        if dest_conn.integration_type != IntegrationType.DESTINATION.value:
            raise HTTPException(
                status_code=400,
                detail=f"Connection {destination_connection_id} is not a destination",
            )

        # 3. Check if slot already exists (fork is not idempotent by design)
        existing = await crud.sync_connection.get_by_sync_and_connection(
            self.db, sync_id=sync_id, connection_id=destination_connection_id
        )
        if existing:
            raise HTTPException(
                status_code=400,
                detail=f"Destination {destination_connection_id} already exists for sync {sync_id}",
            )

        # 4. Create shadow slot. Slot creation and the destination-id list
        # update must land in the same transaction: a slot without a matching
        # entry in sync.destination_connection_ids would be invisible to the
        # existing sync flow.
        async with UnitOfWork(self.db) as uow:
            slot = await crud.sync_connection.create(
                self.db,
                sync_id=sync_id,
                connection_id=destination_connection_id,
                role=DestinationRole.SHADOW,
                uow=uow,
            )

            # Also update sync.destination_connection_ids to include the new destination
            # This ensures backward compatibility with existing sync flow
            current_dest_ids = list(sync.destination_connection_ids or [])
            if destination_connection_id not in current_dest_ids:
                current_dest_ids.append(destination_connection_id)
            await crud.sync.update(
                self.db,
                db_obj=sync,
                obj_in=schemas.SyncUpdate(destination_connection_ids=current_dest_ids),
                ctx=self.ctx,
                uow=uow,
            )

            await uow.commit()

        self.logger.info(
            f"Created shadow slot for sync {sync_id} → {dest_conn.name}",
            extra={"slot_id": str(slot.id), "destination_id": str(destination_connection_id)},
        )

        # 5. Kick off replay if requested (runs after commit, so the slot is
        # durable even if the replay fails)
        replay_job = None
        if replay_from_arf:
            replay_job = await self._start_replay_job(sync_id, slot.id, destination_connection_id)

        return slot, replay_job

    # =========================================================================
    # Switch: Promote shadow to active
    # =========================================================================

    async def switch(
        self,
        sync_id: UUID,
        new_active_slot_id: UUID,
    ) -> schemas.SwitchDestinationResponse:
        """Promote a shadow destination to active.

        Args:
            sync_id: Sync to switch
            new_active_slot_id: Slot ID to promote

        Returns:
            SwitchDestinationResponse with status and slot IDs

        Raises:
            HTTPException: If slot not found or not a shadow
        """
        # 1. Validate sync exists and user has access
        sync = await crud.sync.get(self.db, id=sync_id, ctx=self.ctx, with_connections=False)
        if not sync:
            raise HTTPException(status_code=404, detail=f"Sync {sync_id} not found")

        # 2. Get all slots for this sync
        slots = await crud.sync_connection.get_by_sync_id(self.db, sync_id=sync_id)

        current_active = next((s for s in slots if s.role == DestinationRole.ACTIVE.value), None)
        target_slot = next((s for s in slots if s.id == new_active_slot_id), None)

        if not target_slot:
            raise HTTPException(
                status_code=404, detail=f"Slot {new_active_slot_id} not found for sync {sync_id}"
            )
        # Only SHADOW → ACTIVE is allowed; re-promoting a DEPRECATED slot would
        # require a fresh fork/replay first.
        if target_slot.role != DestinationRole.SHADOW.value:
            raise HTTPException(
                status_code=400,
                detail=f"Slot {new_active_slot_id} is not a shadow (current: {target_slot.role})",
            )

        # 3. Perform the switch atomically: demote-then-promote inside one
        # transaction so there is never a window with two ACTIVE slots.
        async with UnitOfWork(self.db) as uow:
            previous_active_id = None

            # Demote current active (if exists)
            if current_active:
                await crud.sync_connection.update_role(
                    self.db, id=current_active.id, role=DestinationRole.DEPRECATED, uow=uow
                )
                previous_active_id = current_active.id
                self.logger.info(
                    f"Demoted slot {current_active.id} to DEPRECATED",
                    extra={"destination_id": str(current_active.connection_id)},
                )

            # Promote target to active
            await crud.sync_connection.update_role(
                self.db, id=target_slot.id, role=DestinationRole.ACTIVE, uow=uow
            )
            self.logger.info(
                f"Promoted slot {target_slot.id} to ACTIVE",
                extra={"destination_id": str(target_slot.connection_id)},
            )

            await uow.commit()

        return schemas.SwitchDestinationResponse(
            status="switched",
            new_active_slot_id=new_active_slot_id,
            previous_active_slot_id=previous_active_id,
        )

    # =========================================================================
    # List: Show all destinations and their roles
    # =========================================================================

    async def list_destinations(
        self,
        sync_id: UUID,
    ) -> List[schemas.DestinationSlotInfo]:
        """List all destinations for a sync with their roles.

        Args:
            sync_id: Sync ID

        Returns:
            List of DestinationSlotInfo sorted by role (active first)
        """
        # 1. Validate sync exists and user has access
        sync = await crud.sync.get(self.db, id=sync_id, ctx=self.ctx, with_connections=False)
        if not sync:
            raise HTTPException(status_code=404, detail=f"Sync {sync_id} not found")

        # 2. Get all slots
        slots = await crud.sync_connection.get_by_sync_id(self.db, sync_id=sync_id)

        # 3. Get ARF stats for entity count
        # NOTE(review): this is the sync-wide ARF total, so every slot below
        # reports the same entity_count regardless of what has actually been
        # replayed into it - confirm this is the intended semantics.
        arf_stats = await raw_data_service.get_replay_stats(str(sync_id))
        entity_count = arf_stats.get("entity_count", 0) if arf_stats.get("exists") else 0

        # 4. Build response
        # NOTE(review): N+1 connection lookups; acceptable for the handful of
        # slots a sync can have.
        result = []
        for slot in slots:
            # Get connection details
            conn = await crud.connection.get(self.db, id=slot.connection_id, ctx=self.ctx)
            if not conn:
                continue  # Skip if connection was deleted

            result.append(
                schemas.DestinationSlotInfo(
                    slot_id=slot.id,
                    destination_connection_id=slot.connection_id,
                    destination_name=conn.name,
                    destination_short_name=conn.short_name,
                    role=DestinationRole(slot.role),
                    created_at=slot.created_at,
                    entity_count=entity_count,
                )
            )

        # Sort: ACTIVE → SHADOW → DEPRECATED
        role_order = {
            DestinationRole.ACTIVE: 0,
            DestinationRole.SHADOW: 1,
            DestinationRole.DEPRECATED: 2,
        }
        result.sort(key=lambda x: role_order.get(x.role, 99))

        return result

    async def get_active_destination(
        self,
        sync_id: UUID,
    ) -> Optional[schemas.DestinationSlotInfo]:
        """Get the active destination for queries.

        Args:
            sync_id: Sync ID

        Returns:
            Active destination info, or None if no active destination
        """
        slots = await self.list_destinations(sync_id)
        return next((s for s in slots if s.role == DestinationRole.ACTIVE), None)

    # =========================================================================
    # Resync: Force full sync from source to refresh ARF
    # =========================================================================

    async def resync_from_source(
        self,
        sync_id: UUID,
    ) -> schemas.SyncJob:
        """Trigger full sync from source to refresh ARF.

        Ensures ARF is up-to-date before forking to a new destination.
        Uses force_full_sync=True to bypass cursor and get all entities.

        Args:
            sync_id: Sync ID

        Returns:
            SyncJob for tracking progress
        """
        from airweave.core import source_connection_service

        # 1. Validate sync exists
        sync = await crud.sync.get(self.db, id=sync_id, ctx=self.ctx, with_connections=False)
        if not sync:
            raise HTTPException(status_code=404, detail=f"Sync {sync_id} not found")

        # 2. Get source connection for this sync
        source_conn = await crud.source_connection.get_by_sync_id(
            self.db, sync_id=sync_id, ctx=self.ctx
        )
        if not source_conn:
            raise HTTPException(
                status_code=404, detail=f"No source connection found for sync {sync_id}"
            )

        self.logger.info(
            "Triggering full resync from source for ARF refresh",
            extra={"sync_id": str(sync_id), "source_connection_id": str(source_conn.id)},
        )

        # 3. Trigger via existing service (force_full_sync=True)
        job = await source_connection_service.run(
            self.db,
            id=source_conn.id,
            ctx=self.ctx,
            force_full_sync=True,
        )

        # Convert SourceConnectionJob to SyncJob so callers get a uniform type
        return schemas.SyncJob(
            id=job.id,
            sync_id=sync_id,
            organization_id=self.ctx.organization.id,
            status=job.status,
            started_at=job.started_at,
            completed_at=job.completed_at,
            entities_inserted=job.entities_inserted,
            entities_updated=job.entities_updated,
            entities_deleted=job.entities_deleted,
            entities_kept=job.entities_kept,
            entities_skipped=job.entities_skipped,
        )

    # =========================================================================
    # Private: Replay job management
    # =========================================================================

    async def _start_replay_job(
        self,
        sync_id: UUID,
        target_slot_id: UUID,
        target_destination_id: UUID,
    ) -> schemas.SyncJob:
        """Start ARF replay job to populate a shadow destination.

        Uses the SyncOrchestrator with ARFReplaySource for efficient replay
        that reuses all existing pipeline logic.

        Args:
            sync_id: Sync ID
            target_slot_id: Slot ID to replay to
            target_destination_id: Destination connection ID

        Returns:
            SyncJob tracking the replay progress
        """
        from airweave.platform.sync.multiplex.replay import replay_to_destination

        self.logger.info(
            f"Starting ARF replay for slot {target_slot_id}",
            extra={
                "sync_id": str(sync_id),
                "target_slot_id": str(target_slot_id),
                "target_destination_id": str(target_destination_id),
            },
        )

        return await replay_to_destination(
            db=self.db,
            ctx=self.ctx,
            sync_id=sync_id,
            target_connection_id=target_destination_id,
        )


async def get_multiplexer(db: AsyncSession, ctx: ApiContext) -> SyncMultiplexer:
    """Get a SyncMultiplexer instance.

    Args:
        db: Database session
        ctx: API context

    Returns:
        SyncMultiplexer instance
    """
    return SyncMultiplexer(db, ctx, ctx.logger)
class ARFReplaySource(BaseSource):
    """Pseudo-source that streams previously captured entities out of ARF storage.

    Lets the standard SyncOrchestrator pipeline run a replay: instead of
    fetching from an external API, `generate_entities` walks the raw-data
    store recorded for a sync.
    """

    _name = "ARF Replay"
    _short_name = "arf_replay"
    _auth_type = None

    def __init__(self, sync_id: str, logger: ContextualLogger):
        """Remember which sync to replay and which logger reports progress.

        Args:
            sync_id: Sync ID whose ARF store should be replayed
            logger: Contextual logger
        """
        self.logger = logger
        self._sync_id = sync_id

    async def generate_entities(self) -> AsyncGenerator[BaseEntity, None]:
        """Yield each entity reconstructed from the ARF store, in order.

        Yields:
            BaseEntity instances reconstructed from ARF
        """
        self.logger.info(f"Starting ARF replay for sync {self._sync_id}")

        count = 0
        entity_stream = raw_data_service.iter_entities_for_replay(self._sync_id)
        async for entity in entity_stream:
            count += 1
            # Progress heartbeat once per hundred entities.
            if not count % 100:
                self.logger.debug(f"Replayed {count} entities from ARF")
            yield entity

        self.logger.info(f"ARF replay completed: {count} entities yielded")


async def replay_to_destination(
    db: AsyncSession,
    ctx: ApiContext,
    sync_id: UUID,
    target_connection_id: UUID,
) -> schemas.SyncJob:
    """Replay entities from ARF into one destination and return its tracking job.

    Thin high-level wrapper: builds a replay orchestrator, runs it to
    completion, and hands back the sync job it created.

    Args:
        db: Database session
        ctx: API context
        sync_id: Sync ID to replay from
        target_connection_id: Destination connection ID

    Returns:
        SyncJob tracking the replay progress

    Raises:
        ValueError: If ARF store doesn't exist
        NotFoundException: If sync or destination not found
    """
    replay_orchestrator = await create_replay_orchestrator(
        db=db,
        ctx=ctx,
        sync_id=sync_id,
        target_connection_id=target_connection_id,
    )
    await replay_orchestrator.run()
    return replay_orchestrator.sync_context.sync_job
+ + Uses builders for modular construction: + - DestinationBuilder: Creates target destination + - ReplayContextBuilder: Creates lightweight context + - PipelineBuilder: Creates pipeline (without RawDataHandler) + + Args: + db: Database session + ctx: API context + sync_id: Sync ID to replay from + target_connection_id: Target destination connection ID + max_workers: Max concurrent workers + + Returns: + SyncOrchestrator configured for replay + """ + from airweave.core.config import settings + + if max_workers is None: + max_workers = settings.SYNC_MAX_WORKERS + + init_start = time.time() + + # 1. Validate ARF store exists + arf_stats = await raw_data_service.get_replay_stats(str(sync_id)) + if not arf_stats.get("exists"): + raise ValueError(f"No ARF store found for sync {sync_id}") + + entity_count = arf_stats.get("entity_count", 0) + + # 2. Get sync and source connection + sync = await crud.sync.get(db, id=sync_id, ctx=ctx, with_connections=True) + if not sync: + raise NotFoundException(f"Sync {sync_id} not found") + + source_conn = await crud.source_connection.get_by_sync_id(db, sync_id=sync_id, ctx=ctx) + if not source_conn: + raise NotFoundException(f"No source connection found for sync {sync_id}") + + collection = await crud.collection.get_by_readable_id( + db, readable_id=source_conn.readable_collection_id, ctx=ctx + ) + if not collection: + raise NotFoundException(f"Collection not found for sync {sync_id}") + + collection_schema = schemas.Collection.model_validate(collection, from_attributes=True) + + # 3. Create sync job + async with UnitOfWork(db) as uow: + sync_job = await crud.sync_job.create( + db, + obj_in=schemas.SyncJobCreate( + sync_id=sync_id, + status=SyncJobStatus.PENDING, + scheduled=False, + ), + ctx=ctx, + uow=uow, + ) + await uow.commit() + + sync_job_schema = schemas.SyncJob.model_validate(sync_job, from_attributes=True) + + # 4. 
Create contextual logger + replay_logger = LoggerConfigurator.configure_logger( + "airweave.platform.sync.replay", + dimensions={ + "sync_id": str(sync_id), + "sync_job_id": str(sync_job.id), + "organization_id": str(ctx.organization.id), + "target_connection_id": str(target_connection_id), + "mode": "replay", + }, + ) + + replay_logger.info( + f"Starting replay from ARF for sync {sync_id} → connection {target_connection_id}", + extra={"entity_count": entity_count}, + ) + + # 5. Build destination + dest_builder = DestinationBuilder(db, ctx, replay_logger) + destinations = await dest_builder.build_for_ids( + destination_ids=[target_connection_id], + collection=collection_schema, + sync_id=sync_id, + ) + + if not destinations: + raise ValueError(f"Could not create destination for connection {target_connection_id}") + + # 6. Get entity map + entity_map = await SyncFactory._get_entity_definition_map(db) + + # 7. Create ARF source + arf_source = ARFReplaySource(str(sync_id), replay_logger) + + # 8. Build context (using replay-specific builder) + context_builder = ReplayContextBuilder(db, ctx, replay_logger) + sync_context = await context_builder.build( + source=arf_source, + destinations=destinations, + sync=sync, + sync_job=sync_job_schema, + collection=collection_schema, + entity_map=entity_map, + ) + + # 9. Build pipeline (skip RawDataHandler for replay) + entity_pipeline = PipelineBuilder.build( + sync_context=sync_context, + include_raw_data_handler=False, + ) + + # 10. Create worker pool and stream + worker_pool = AsyncWorkerPool(max_workers=max_workers, logger=replay_logger) + stream = AsyncSourceStream( + source_generator=arf_source.generate_entities(), + queue_size=10000, + logger=replay_logger, + ) + + # 11. 
Create orchestrator + orchestrator = SyncOrchestrator( + entity_pipeline=entity_pipeline, + worker_pool=worker_pool, + stream=stream, + sync_context=sync_context, + ) + + replay_logger.info(f"Replay orchestrator created in {time.time() - init_start:.2f}s") + return orchestrator diff --git a/backend/airweave/schemas/__init__.py b/backend/airweave/schemas/__init__.py index ffe8800ce..50e41977a 100644 --- a/backend/airweave/schemas/__init__.py +++ b/backend/airweave/schemas/__init__.py @@ -136,6 +136,21 @@ SyncWithoutConnections, SyncWithSourceConnection, ) +from .sync_connection import ( + DestinationSlotInfo, + ForkDestinationRequest, + ForkDestinationResponse, + SwitchDestinationResponse, +) +from .sync_connection import ( + SyncConnection as SyncConnectionSchema, +) +from .sync_connection import ( + SyncConnectionCreate as SyncConnectionCreateSchema, +) +from .sync_connection import ( + SyncConnectionUpdate as SyncConnectionUpdateSchema, +) from .sync_cursor import ( SyncCursor, SyncCursorBase, @@ -174,3 +189,4 @@ OrganizationWithRole.model_rebuild() UserOrganization.model_rebuild() User.model_rebuild() +ForkDestinationResponse.model_rebuild() diff --git a/backend/airweave/schemas/sync_connection.py b/backend/airweave/schemas/sync_connection.py new file mode 100644 index 000000000..6ffa77ba4 --- /dev/null +++ b/backend/airweave/schemas/sync_connection.py @@ -0,0 +1,89 @@ +"""Schemas for sync connection (destination multiplexing).""" + +from datetime import datetime +from typing import Optional +from uuid import UUID + +from pydantic import BaseModel, Field + +from airweave.models.sync_connection import DestinationRole + + +class SyncConnectionBase(BaseModel): + """Base schema for sync connection.""" + + sync_id: UUID + connection_id: UUID + role: DestinationRole = DestinationRole.ACTIVE + + class Config: + """Pydantic config.""" + + from_attributes = True + + +class SyncConnectionCreate(BaseModel): + """Schema for creating a sync connection.""" + + connection_id: 
class SyncConnectionUpdate(BaseModel):
    """Payload for mutating an existing sync connection (role changes only)."""

    role: Optional[DestinationRole] = None


class SyncConnection(SyncConnectionBase):
    """API representation of a persisted sync connection row."""

    id: UUID
    created_at: datetime
    modified_at: Optional[datetime] = None


class DestinationSlotInfo(BaseModel):
    """Multiplexer-facing view of a single destination slot."""

    slot_id: UUID = Field(..., description="Sync connection ID")
    destination_connection_id: UUID = Field(..., description="Connection ID of the destination")
    destination_name: str = Field(..., description="Name of the destination connection")
    destination_short_name: str = Field(..., description="Short name of the destination type")
    role: DestinationRole = Field(..., description="Current role (active/shadow/deprecated)")
    created_at: datetime = Field(..., description="When this slot was created")
    entity_count: int = Field(default=0, description="Entity count in ARF store")

    class Config:
        """Pydantic config."""

        from_attributes = True


class ForkDestinationRequest(BaseModel):
    """Payload for attaching (forking) an additional destination to a sync."""

    destination_connection_id: UUID = Field(
        ..., description="ID of the destination connection to add"
    )
    replay_from_arf: bool = Field(
        default=False, description="Whether to auto-replay entities from ARF store"
    )


class SwitchDestinationResponse(BaseModel):
    """Result of promoting a slot to the active role."""

    status: str = "switched"
    new_active_slot_id: UUID
    previous_active_slot_id: Optional[UUID] = None


class ForkDestinationResponse(BaseModel):
    """Result of a fork operation, optionally carrying replay-job details."""

    slot: "SyncConnection"
    replay_job_id: Optional[UUID] = Field(
        default=None, description="Replay job ID if replay_from_arf was requested"
    )
    replay_job_status: Optional[str] = Field(default=None, description="Status of the replay job")
"""Add role column to sync_connection for destination multiplexing.

Revision ID: h1i2j3k4l5m6
Revises: g0a9b8c7d6e5
Create Date: 2024-12-29 12:00:00.000000

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "h1i2j3k4l5m6"
down_revision = "g0a9b8c7d6e5"
branch_labels = None
depends_on = None

_TABLE = "sync_connection"
_COLUMN = "role"


def upgrade() -> None:
    """Add the ``role`` column enabling blue-green destination multiplexing.

    Role semantics:
    - active: receives writes + serves queries
    - shadow: receives writes only (for testing)
    - deprecated: no longer in use (kept for rollback)
    """
    # server_default="active" backfills every existing row during ADD COLUMN...
    op.add_column(
        _TABLE,
        sa.Column(_COLUMN, sa.String(20), nullable=False, server_default="active"),
    )
    # ...then the default is removed so the application supplies the role.
    op.alter_column(_TABLE, _COLUMN, server_default=None)


def downgrade() -> None:
    """Drop the ``role`` column from the sync_connection table."""
    op.drop_column(_TABLE, _COLUMN)
.../platform/sync/factory/_pipeline.py | 31 ++++++-- .../platform/sync/handlers/postgres.py | 19 ++++- .../platform/sync/multiplex/multiplexer.py | 19 +++-- .../platform/temporal/activities/sync.py | 17 +++++ backend/airweave/schemas/sync_job.py | 3 +- .../add_execution_config_json_to_sync_job.py | 41 +++++++++++ 15 files changed, 312 insertions(+), 19 deletions(-) create mode 100644 backend/airweave/platform/sync/config.py create mode 100644 backend/alembic/versions/add_execution_config_json_to_sync_job.py diff --git a/backend/airweave/core/source_connection_service.py b/backend/airweave/core/source_connection_service.py index b98271a6f..8d3d2ce1c 100644 --- a/backend/airweave/core/source_connection_service.py +++ b/backend/airweave/core/source_connection_service.py @@ -1488,6 +1488,7 @@ async def run( id: UUID, ctx: ApiContext, force_full_sync: bool = False, + execution_config: Optional[Dict[str, Any]] = None, ) -> schemas.SourceConnectionJob: """Trigger a sync run for a source connection. @@ -1498,6 +1499,7 @@ async def run( force_full_sync: If True, forces a full sync with orphaned entity cleanup. Only allowed for continuous syncs (syncs with cursor data). Raises HTTPException if used on non-continuous syncs. 
+ execution_config: Optional execution config dict to persist in DB for worker """ source_conn = await crud.source_connection.get(db, id=id, ctx=ctx) if not source_conn: @@ -1541,9 +1543,9 @@ async def run( db=db, source_connection=source_conn, ctx=ctx ) - # Trigger sync through Temporal only + # Trigger sync through Temporal only (stores execution_config in DB) sync, sync_job = await sync_service.trigger_sync_run( - db, sync_id=source_conn.sync_id, ctx=ctx + db, sync_id=source_conn.sync_id, ctx=ctx, execution_config=execution_config ) await temporal_service.run_source_connection_workflow( diff --git a/backend/airweave/core/sync_service.py b/backend/airweave/core/sync_service.py index 2323551e1..1cb4351b4 100644 --- a/backend/airweave/core/sync_service.py +++ b/backend/airweave/core/sync_service.py @@ -1,6 +1,6 @@ """Refactored sync service with Temporal-only execution.""" -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from uuid import UUID from fastapi import HTTPException @@ -15,6 +15,7 @@ from airweave.db.unit_of_work import UnitOfWork from airweave.models.sync import Sync from airweave.models.sync_job import SyncJob +from airweave.platform.sync.config import SyncExecutionConfig from airweave.platform.sync.factory import SyncFactory from airweave.platform.temporal.schedule_service import temporal_schedule_service @@ -100,6 +101,7 @@ async def run( ctx: ApiContext, access_token: Optional[str] = None, force_full_sync: bool = False, + execution_config: Optional[SyncExecutionConfig] = None, ) -> schemas.Sync: """Run a sync. @@ -113,6 +115,8 @@ async def run( access_token (Optional[str]): Optional access token to use instead of stored credentials. force_full_sync (bool): If True, forces a full sync with orphaned entity deletion. + execution_config (Optional[SyncExecutionConfig]): Optional execution config + for controlling sync behavior (destination filtering, handler toggles, etc.) 
Returns: ------- @@ -130,6 +134,7 @@ async def run( ctx=ctx, access_token=access_token, force_full_sync=force_full_sync, + execution_config=execution_config, ) except Exception as e: ctx.logger.error(f"Error during sync orchestrator creation: {e}") @@ -151,6 +156,7 @@ async def trigger_sync_run( db: AsyncSession, sync_id: UUID, ctx: ApiContext, + execution_config: Optional[Dict[str, Any]] = None, ) -> Tuple[schemas.Sync, schemas.SyncJob]: """Trigger a manual sync run. @@ -158,6 +164,7 @@ async def trigger_sync_run( db: Database session sync_id: Sync ID to run ctx: API context + execution_config: Optional execution config dict to persist in DB Returns: Tuple of (sync, sync_job) schemas @@ -192,7 +199,9 @@ async def trigger_sync_run( # Create sync job async with UnitOfWork(db) as uow: - sync_job = await self._create_sync_job(uow.session, sync_id, ctx, uow) + sync_job = await self._create_sync_job( + uow.session, sync_id, ctx, uow, execution_config + ) await uow.commit() await uow.session.refresh(sync_job) @@ -206,11 +215,13 @@ async def _create_sync_job( sync_id: UUID, ctx: ApiContext, uow: UnitOfWork, + execution_config: Optional[Dict[str, Any]] = None, ) -> SyncJob: """Create a sync job record.""" sync_job_in = schemas.SyncJobCreate( sync_id=sync_id, status=SyncJobStatus.PENDING, + execution_config_json=execution_config, ) return await crud.sync_job.create(db, obj_in=sync_job_in, ctx=ctx, uow=uow) diff --git a/backend/airweave/models/sync_job.py b/backend/airweave/models/sync_job.py index 27f03df55..ca851549f 100644 --- a/backend/airweave/models/sync_job.py +++ b/backend/airweave/models/sync_job.py @@ -5,6 +5,7 @@ from uuid import UUID from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, Index, Integer, String +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column, relationship from airweave.core.shared_models import SyncJobStatus @@ -35,6 +36,7 @@ class SyncJob(OrganizationBase, UserMixin): error: 
Mapped[Optional[str]] = mapped_column(String, nullable=True) entities_encountered: Mapped[Optional[dict]] = mapped_column(JSON, default={}) scheduled: Mapped[bool] = mapped_column(Boolean, default=False) + execution_config_json: Mapped[Optional[dict]] = mapped_column(JSONB, nullable=True) sync: Mapped["Sync"] = relationship( "Sync", diff --git a/backend/airweave/platform/sync/actions/resolver.py b/backend/airweave/platform/sync/actions/resolver.py index 2d86b7a5f..8dff37edc 100644 --- a/backend/airweave/platform/sync/actions/resolver.py +++ b/backend/airweave/platform/sync/actions/resolver.py @@ -65,6 +65,13 @@ async def resolve( Raises: SyncFailureError: If entity type not found in entity_map or missing hash """ + # Check if skip_hash_comparison is enabled + if sync_context.execution_config and sync_context.execution_config.skip_hash_comparison: + sync_context.logger.info( + "skip_hash_comparison enabled: Forcing all entities as INSERT actions" + ) + return self._force_all_inserts(entities, sync_context) + # Step 1: Separate deletions from non-deletions delete_entities, non_delete_entities = self._separate_deletions(entities) @@ -88,6 +95,44 @@ async def resolve( return batch + def _force_all_inserts( + self, + entities: List[BaseEntity], + sync_context: "SyncContext", + ) -> ActionBatch: + """Force all entities as INSERT actions (skip hash comparison). + + Used for ARF replay or when execution_config.skip_hash_comparison is True. 
"""Sync execution configuration for controlling sync behavior."""

from typing import List, Literal, Optional
from uuid import UUID

from pydantic import BaseModel, Field


class SyncExecutionConfig(BaseModel):
    """Declarative sync execution configuration.

    Each component reads only the flags it needs - highly modular.
    Config is persisted in sync_job.execution_config_json to avoid Temporal bloat.
    """

    # --- Destination selection -------------------------------------------
    target_destinations: Optional[List[UUID]] = Field(
        None, description="If set, ONLY write to these destinations"
    )
    exclude_destinations: Optional[List[UUID]] = Field(
        None, description="Skip these destinations"
    )
    # Literal (rather than a free-form str) rejects misspelled strategies
    # at construction time instead of silently falling through downstream.
    destination_strategy: Literal[
        "active_only", "shadow_only", "all", "active_and_shadow"
    ] = Field(
        "active_and_shadow",
        description="active_only|shadow_only|all|active_and_shadow",
    )

    # --- Handler toggles --------------------------------------------------
    enable_vector_handlers: bool = Field(True, description="Enable VectorDBHandler")
    enable_raw_data_handler: bool = Field(True, description="Enable RawDataHandler (ARF)")
    enable_postgres_handler: bool = Field(True, description="Enable PostgresMetadataHandler")

    # --- Behavior flags ---------------------------------------------------
    skip_hash_comparison: bool = Field(False, description="Force INSERT for all entities")
    skip_hash_updates: bool = Field(
        False, description="Don't update content_hash column"
    )

    # --- Performance ------------------------------------------------------
    max_workers: int = Field(20, description="Max concurrent workers")
    batch_size: int = Field(100, description="Entity batch size")

    @classmethod
    def default(cls) -> "SyncExecutionConfig":
        """Normal sync to active+shadow destinations."""
        return cls()

    @classmethod
    def arf_capture_only(cls) -> "SyncExecutionConfig":
        """Capture to ARF without vector DBs or hash updates.

        Used by multiplexer.resync_from_source() to populate ARF
        without touching production vector databases.
        """
        return cls(
            enable_vector_handlers=False,
            skip_hash_updates=True,
        )

    @classmethod
    def replay_to_destination(cls, destination_id: UUID) -> "SyncExecutionConfig":
        """Replay from ARF to a specific destination.

        Used by multiplexer.fork() to populate a shadow destination
        from ARF without re-fetching from source.

        Args:
            destination_id: The sole destination connection to write to.
        """
        return cls(
            target_destinations=[destination_id],
            enable_raw_data_handler=False,
            skip_hash_comparison=True,
        )
has_keyword_index: bool = False + # Optional execution config for controlling sync behavior + execution_config: Optional[SyncExecutionConfig] = None # batching knobs (read by SyncOrchestrator at init) should_batch: bool = True @@ -81,6 +85,7 @@ def __init__( batch_size: int = 64, max_batch_latency_ms: int = 200, has_keyword_index: bool = False, + execution_config: Optional[SyncExecutionConfig] = None, ): """Initialize the sync context.""" self.source = source @@ -97,6 +102,7 @@ def __init__( self.guard_rail = guard_rail self.logger = logger self.force_full_sync = force_full_sync + self.execution_config = execution_config # Concurrency / batching knobs self.should_batch = should_batch diff --git a/backend/airweave/platform/sync/factory/_context.py b/backend/airweave/platform/sync/factory/_context.py index c4df49493..26a606571 100644 --- a/backend/airweave/platform/sync/factory/_context.py +++ b/backend/airweave/platform/sync/factory/_context.py @@ -4,6 +4,7 @@ """ import asyncio +from typing import Optional from sqlalchemy.ext.asyncio import AsyncSession @@ -14,6 +15,7 @@ from airweave.core.sync_cursor_service import sync_cursor_service from airweave.platform.destinations._base import BaseDestination from airweave.platform.sources._base import BaseSource +from airweave.platform.sync.config import SyncExecutionConfig from airweave.platform.sync.context import SyncContext from airweave.platform.sync.cursor import SyncCursor from airweave.platform.sync.pipeline.entity_tracker import EntityTracker @@ -52,6 +54,7 @@ async def build( collection: schemas.Collection, entity_map: dict, force_full_sync: bool = False, + execution_config: Optional[SyncExecutionConfig] = None, ) -> SyncContext: """Build a complete SyncContext.""" # 1. Load initial entity counts @@ -112,6 +115,7 @@ async def build( guard_rail=guard_rail, force_full_sync=force_full_sync, has_keyword_index=has_keyword_index, + execution_config=execution_config, ) # 8. 
Set cursor on source diff --git a/backend/airweave/platform/sync/factory/_destination.py b/backend/airweave/platform/sync/factory/_destination.py index d14ec3f9d..35a0923b0 100644 --- a/backend/airweave/platform/sync/factory/_destination.py +++ b/backend/airweave/platform/sync/factory/_destination.py @@ -15,6 +15,7 @@ from airweave.core.logging import ContextualLogger from airweave.platform.destinations._base import BaseDestination from airweave.platform.locator import resource_locator +from airweave.platform.sync.config import SyncExecutionConfig class DestinationBuilder: @@ -42,6 +43,7 @@ async def build( self, sync: schemas.Sync, collection: schemas.Collection, + execution_config: Optional[SyncExecutionConfig] = None, ) -> list[BaseDestination]: """Build destination instances for a sync. @@ -49,8 +51,10 @@ async def build( - ACTIVE: receives writes + serves queries - SHADOW: receives writes only (migration testing) - DEPRECATED: skipped (no writes) + + Also respects execution_config for destination filtering. """ - destination_ids = await self._get_active_ids(sync) + destination_ids = await self._get_destination_ids(sync, execution_config) destinations = await self.build_for_ids( destination_ids=destination_ids, @@ -70,6 +74,43 @@ async def build( ) return destinations + + async def _get_destination_ids( + self, + sync: schemas.Sync, + execution_config: Optional[SyncExecutionConfig], + ) -> list[UUID]: + """Get destination IDs based on roles and execution config. + + Priority: + 1. execution_config.target_destinations (if set, use only these) + 2. execution_config.exclude_destinations (filter out from active/shadow) + 3. 
Default: all active+shadow destinations + """ + # If target_destinations is set, use only those (highest priority) + if execution_config and execution_config.target_destinations: + self.logger.info( + f"Using target_destinations from config: {execution_config.target_destinations}" + ) + return execution_config.target_destinations + + # Get base destination IDs (active + shadow by default) + destination_ids = await self._get_active_ids(sync) + + # If exclude_destinations is set, filter them out + if execution_config and execution_config.exclude_destinations: + original_count = len(destination_ids) + destination_ids = [ + dest_id for dest_id in destination_ids + if dest_id not in execution_config.exclude_destinations + ] + excluded_count = original_count - len(destination_ids) + if excluded_count > 0: + self.logger.info( + f"Excluded {excluded_count} destination(s) based on config" + ) + + return destination_ids async def build_for_ids( self, diff --git a/backend/airweave/platform/sync/factory/_factory.py b/backend/airweave/platform/sync/factory/_factory.py index 5d7c4a53e..77bc15855 100644 --- a/backend/airweave/platform/sync/factory/_factory.py +++ b/backend/airweave/platform/sync/factory/_factory.py @@ -17,6 +17,7 @@ from airweave.core.logging import LoggerConfigurator, logger from airweave.db.init_db_native import init_db_with_entity_definitions from airweave.platform.entities._base import BaseEntity +from airweave.platform.sync.config import SyncExecutionConfig from airweave.platform.sync.factory._context import ContextBuilder from airweave.platform.sync.factory._destination import DestinationBuilder from airweave.platform.sync.factory._pipeline import PipelineBuilder @@ -55,6 +56,7 @@ async def create_orchestrator( access_token: Optional[str] = None, max_workers: Optional[int] = None, force_full_sync: bool = False, + execution_config: Optional[SyncExecutionConfig] = None, ) -> SyncOrchestrator: """Create a sync orchestrator with all required components. 
@@ -68,6 +70,7 @@ async def create_orchestrator( access_token: Optional token override max_workers: Max concurrent workers (default: from settings) force_full_sync: Whether to force full sync (skips cursor) + execution_config: Optional execution config for controlling sync behavior Returns: Configured SyncOrchestrator ready to run @@ -115,7 +118,9 @@ async def create_orchestrator( # 3. Build destinations dest_builder = DestinationBuilder(db, ctx, sync_logger) - destinations = await dest_builder.build(sync=sync, collection=collection) + destinations = await dest_builder.build( + sync=sync, collection=collection, execution_config=execution_config + ) # 4. Get entity map entity_map = await cls._get_entity_definition_map(db) @@ -131,6 +136,7 @@ async def create_orchestrator( collection=collection, entity_map=entity_map, force_full_sync=force_full_sync, + execution_config=execution_config, ) # 6. Build pipeline diff --git a/backend/airweave/platform/sync/factory/_pipeline.py b/backend/airweave/platform/sync/factory/_pipeline.py index 3773e1d62..99cd512bf 100644 --- a/backend/airweave/platform/sync/factory/_pipeline.py +++ b/backend/airweave/platform/sync/factory/_pipeline.py @@ -37,6 +37,7 @@ def build( destinations=sync_context.destinations, logger=sync_context.logger, include_raw_data_handler=include_raw_data_handler, + sync_context=sync_context, ) action_dispatcher = ActionDispatcher(handlers=handlers) @@ -78,10 +79,17 @@ def _create_handlers( destinations: list[BaseDestination], logger, include_raw_data_handler: bool = True, + sync_context: SyncContext = None, ) -> list[ActionHandler]: - """Create handlers based on destination requirements.""" + """Create handlers based on destination requirements and execution config.""" handlers: list[ActionHandler] = [] + # Check execution config for handler toggles + config = sync_context.execution_config if sync_context else None + enable_vector = config is None or config.enable_vector_handlers + enable_raw = config is None or 
config.enable_raw_data_handler + enable_postgres = config is None or config.enable_postgres_handler + vector_db_destinations: list[BaseDestination] = [] self_processing_destinations: list[BaseDestination] = [] @@ -99,7 +107,8 @@ def _create_handlers( ) vector_db_destinations.append(dest) - if vector_db_destinations: + # Only add VectorDBHandler if enabled + if vector_db_destinations and enable_vector: vector_handler = VectorDBHandler(destinations=vector_db_destinations) handlers.append(vector_handler) if logger: @@ -107,11 +116,25 @@ def _create_handlers( f"Created VectorDBHandler for {len(vector_db_destinations)} destination(s): " f"{[d.__class__.__name__ for d in vector_db_destinations]}" ) + elif vector_db_destinations and not enable_vector: + if logger: + logger.info( + f"Skipping VectorDBHandler (disabled by execution_config) for " + f"{len(vector_db_destinations)} destination(s)" + ) - if include_raw_data_handler: + # Only add RawDataHandler if enabled + if include_raw_data_handler and enable_raw: handlers.append(RawDataHandler()) + elif include_raw_data_handler and not enable_raw: + if logger: + logger.info("Skipping RawDataHandler (disabled by execution_config)") - handlers.append(PostgresMetadataHandler()) + # Only add PostgresMetadataHandler if enabled + if enable_postgres: + handlers.append(PostgresMetadataHandler()) + elif logger: + logger.info("Skipping PostgresMetadataHandler (disabled by execution_config)") if not handlers and logger: logger.warning("No destination handlers created - sync has no valid destinations") diff --git a/backend/airweave/platform/sync/handlers/postgres.py b/backend/airweave/platform/sync/handlers/postgres.py index bde0a44fd..3060cb5fe 100644 --- a/backend/airweave/platform/sync/handlers/postgres.py +++ b/backend/airweave/platform/sync/handlers/postgres.py @@ -149,10 +149,15 @@ async def _execute_inserts( if not deduped: return + # Check if hash updates should be skipped + skip_hashes = ( + sync_context.execution_config and 
sync_context.execution_config.skip_hash_updates + ) + # Build create objects with deterministic ordering create_objs = [] for action in deduped: - if not action.entity.airweave_system_metadata.hash: + if not skip_hashes and not action.entity.airweave_system_metadata.hash: raise SyncFailureError(f"Entity {action.entity_id} missing hash") create_objs.append( @@ -161,7 +166,7 @@ async def _execute_inserts( sync_id=sync_context.sync.id, entity_id=action.entity_id, entity_definition_id=action.entity_definition_id, - hash=action.entity.airweave_system_metadata.hash, + hash=None if skip_hashes else action.entity.airweave_system_metadata.hash, ) ) @@ -170,8 +175,9 @@ async def _execute_inserts( # Log for debugging sample_ids = [obj.entity_id for obj in create_objs[:10]] + hash_note = " (without hashes)" if skip_hashes else "" sync_context.logger.debug( - f"[Postgres] Upserting {len(create_objs)} inserts (sample: {sample_ids})" + f"[Postgres] Upserting {len(create_objs)} inserts{hash_note} (sample: {sample_ids})" ) await crud.entity.bulk_create(db, objs=create_objs, ctx=sync_context.ctx) @@ -191,6 +197,13 @@ async def _execute_updates( sync_context: Sync context db: Database session """ + # Check if hash updates should be skipped + if sync_context.execution_config and sync_context.execution_config.skip_hash_updates: + sync_context.logger.info( + f"[Postgres] Skipping hash updates for {len(actions)} entities (skip_hash_updates=True)" + ) + return + update_pairs = [] for action in actions: diff --git a/backend/airweave/platform/sync/multiplex/multiplexer.py b/backend/airweave/platform/sync/multiplex/multiplexer.py index 559b90902..6b45947b9 100644 --- a/backend/airweave/platform/sync/multiplex/multiplexer.py +++ b/backend/airweave/platform/sync/multiplex/multiplexer.py @@ -293,10 +293,10 @@ async def resync_from_source( self, sync_id: UUID, ) -> schemas.SyncJob: - """Trigger full sync from source to refresh ARF. + """Trigger ARF-only sync from source to refresh ARF store. 
Ensures ARF is up-to-date before forking to a new destination. - Uses force_full_sync=True to bypass cursor and get all entities. + Uses ARF-only execution config to skip vector DB writes and hash updates. Args: sync_id: Sync ID @@ -305,6 +305,7 @@ async def resync_from_source( SyncJob for tracking progress """ from airweave.core import source_connection_service + from airweave.platform.sync.config import SyncExecutionConfig # 1. Validate sync exists sync = await crud.sync.get(self.db, id=sync_id, ctx=self.ctx, with_connections=False) @@ -320,17 +321,25 @@ async def resync_from_source( status_code=404, detail=f"No source connection found for sync {sync_id}" ) + # 3. Create ARF-only execution config + config = SyncExecutionConfig.arf_capture_only() + self.logger.info( - "Triggering full resync from source for ARF refresh", - extra={"sync_id": str(sync_id), "source_connection_id": str(source_conn.id)}, + "Triggering ARF-only resync from source (no vector DB writes)", + extra={ + "sync_id": str(sync_id), + "source_connection_id": str(source_conn.id), + "execution_config": config.model_dump(), + }, ) - # 3. Trigger via existing service (force_full_sync=True) + # 4. 
Trigger via existing service with ARF-only config job = await source_connection_service.run( self.db, id=source_conn.id, ctx=self.ctx, force_full_sync=True, + execution_config=config.model_dump(), ) # Convert SourceConnectionJob to SyncJob diff --git a/backend/airweave/platform/temporal/activities/sync.py b/backend/airweave/platform/temporal/activities/sync.py index 236ea13b2..9b6a1e7a9 100644 --- a/backend/airweave/platform/temporal/activities/sync.py +++ b/backend/airweave/platform/temporal/activities/sync.py @@ -21,8 +21,24 @@ async def _run_sync_task( force_full_sync=False, ): """Run the actual sync service.""" + from airweave import crud from airweave.core.exceptions import NotFoundException from airweave.core.sync_service import sync_service + from airweave.db.session import get_db_context + from airweave.platform.sync.config import SyncExecutionConfig + + # Refetch sync_job from DB to get execution_config_json + execution_config = None + try: + async with get_db_context() as db: + sync_job_model = await crud.sync_job.get(db, id=sync_job.id) + if sync_job_model and sync_job_model.execution_config_json: + execution_config = SyncExecutionConfig(**sync_job_model.execution_config_json) + ctx.logger.info( + f"Loaded execution config from DB: {sync_job_model.execution_config_json}" + ) + except Exception as e: + ctx.logger.warning(f"Failed to load execution config from DB: {e}") try: return await sync_service.run( @@ -33,6 +49,7 @@ async def _run_sync_task( ctx=ctx, access_token=access_token, force_full_sync=force_full_sync, + execution_config=execution_config, ) except NotFoundException as e: # Check if this is the specific "Source connection record not found" error diff --git a/backend/airweave/schemas/sync_job.py b/backend/airweave/schemas/sync_job.py index 6e9e02ca2..745533fc9 100644 --- a/backend/airweave/schemas/sync_job.py +++ b/backend/airweave/schemas/sync_job.py @@ -6,7 +6,7 @@ """ from datetime import datetime -from typing import Optional +from typing 
import Any, Dict, Optional from uuid import UUID from pydantic import BaseModel, EmailStr, Field @@ -31,6 +31,7 @@ class SyncJobBase(BaseModel): failed_at: Optional[datetime] = None error: Optional[str] = None access_token: Optional[str] = None + execution_config_json: Optional[Dict[str, Any]] = None class Config: """Pydantic config for SyncJobBase.""" diff --git a/backend/alembic/versions/add_execution_config_json_to_sync_job.py b/backend/alembic/versions/add_execution_config_json_to_sync_job.py new file mode 100644 index 000000000..c0c41b81f --- /dev/null +++ b/backend/alembic/versions/add_execution_config_json_to_sync_job.py @@ -0,0 +1,41 @@ +"""Add execution_config_json column to sync_job for config-driven sync execution. + +Revision ID: n7o8p9q0r1s2 +Revises: h1i2j3k4l5m6 +Create Date: 2025-01-02 12:00:00.000000 + +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. +revision = "n7o8p9q0r1s2" +down_revision = "h1i2j3k4l5m6" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Add execution_config_json column to sync_job table. + + Enables config-driven sync execution without bloating Temporal contract. + Worker refetches SyncJob from DB to get execution config. 
+ + Config controls: + - Destination filtering (target/exclude) + - Handler toggles (vector/ARF/postgres) + - Behavior flags (skip hash comparison/updates) + """ + op.add_column( + "sync_job", + sa.Column("execution_config_json", postgresql.JSONB(), nullable=True), + ) + + +def downgrade() -> None: + """Remove execution_config_json column from sync_job table.""" + op.drop_column("sync_job", "execution_config_json") + From 237988f6a2e68488f942dcc303e9696b5f9d65a6 Mon Sep 17 00:00:00 2001 From: Siddhesh Deshpande Date: Fri, 2 Jan 2026 17:18:39 +0100 Subject: [PATCH 3/9] add source connection id to manifest --- backend/airweave/core/source_connection_service.py | 2 +- backend/airweave/platform/sync/context.py | 3 +++ backend/airweave/platform/sync/factory/_context.py | 1 + backend/airweave/platform/sync/raw_data.py | 2 ++ 4 files changed, 7 insertions(+), 1 deletion(-) diff --git a/backend/airweave/core/source_connection_service.py b/backend/airweave/core/source_connection_service.py index 8d3d2ce1c..59ce6d4b7 100644 --- a/backend/airweave/core/source_connection_service.py +++ b/backend/airweave/core/source_connection_service.py @@ -1,7 +1,7 @@ """Clean source connection service with auth method inference.""" from datetime import datetime -from typing import TYPE_CHECKING, Any, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from uuid import UUID if TYPE_CHECKING: diff --git a/backend/airweave/platform/sync/context.py b/backend/airweave/platform/sync/context.py index 497fe760a..af3556786 100644 --- a/backend/airweave/platform/sync/context.py +++ b/backend/airweave/platform/sync/context.py @@ -48,6 +48,7 @@ class SyncContext: cursor: SyncCursor collection: schemas.Collection connection: schemas.Connection + source_connection_id: Optional[UUID] = None entity_map: dict[type[BaseEntity], UUID] ctx: ApiContext guard_rail: GuardRailService @@ -79,6 +80,7 @@ def __init__( ctx: ApiContext, guard_rail: GuardRailService, logger: 
ContextualLogger, + source_connection_id: Optional[UUID] = None, force_full_sync: bool = False, # Micro-batching controls should_batch: bool = True, @@ -97,6 +99,7 @@ def __init__( self.cursor = cursor self.collection = collection self.connection = connection + self.source_connection_id = source_connection_id self.entity_map = entity_map self.ctx = ctx self.guard_rail = guard_rail diff --git a/backend/airweave/platform/sync/factory/_context.py b/backend/airweave/platform/sync/factory/_context.py index 26a606571..34b26d5ec 100644 --- a/backend/airweave/platform/sync/factory/_context.py +++ b/backend/airweave/platform/sync/factory/_context.py @@ -106,6 +106,7 @@ async def build( sync_job=sync_job, collection=collection, connection=None, + source_connection_id=source_connection_data.get("source_connection_id"), entity_tracker=entity_tracker, state_publisher=state_publisher, cursor=cursor, diff --git a/backend/airweave/platform/sync/raw_data.py b/backend/airweave/platform/sync/raw_data.py index ab2aa002d..232587d8a 100644 --- a/backend/airweave/platform/sync/raw_data.py +++ b/backend/airweave/platform/sync/raw_data.py @@ -51,6 +51,7 @@ class SyncManifest(BaseModel): sync_id: str source_short_name: str + source_connection_id: Optional[str] = None collection_id: str collection_readable_id: str organization_id: str @@ -450,6 +451,7 @@ async def _update_manifest( manifest = SyncManifest( sync_id=sync_id, source_short_name=self._get_source_short_name(sync_context), + source_connection_id=str(sync_context.source_connection_id) if sync_context.source_connection_id else None, collection_id=str(sync_context.collection.id), collection_readable_id=sync_context.collection.readable_id, organization_id=str(sync_context.collection.organization_id), From 57e91b355899d033ad84e4c820c74cbdcc6b2c06 Mon Sep 17 00:00:00 2001 From: Siddhesh Deshpande Date: Fri, 2 Jan 2026 17:46:50 +0100 Subject: [PATCH 4/9] Revert "add source connection id to manifest" This reverts commit 
237988f6a2e68488f942dcc303e9696b5f9d65a6. --- backend/airweave/core/source_connection_service.py | 2 +- backend/airweave/platform/sync/context.py | 3 --- backend/airweave/platform/sync/factory/_context.py | 1 - backend/airweave/platform/sync/raw_data.py | 2 -- 4 files changed, 1 insertion(+), 7 deletions(-) diff --git a/backend/airweave/core/source_connection_service.py b/backend/airweave/core/source_connection_service.py index 59ce6d4b7..8d3d2ce1c 100644 --- a/backend/airweave/core/source_connection_service.py +++ b/backend/airweave/core/source_connection_service.py @@ -1,7 +1,7 @@ """Clean source connection service with auth method inference.""" from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, List, Optional, Tuple from uuid import UUID if TYPE_CHECKING: diff --git a/backend/airweave/platform/sync/context.py b/backend/airweave/platform/sync/context.py index af3556786..497fe760a 100644 --- a/backend/airweave/platform/sync/context.py +++ b/backend/airweave/platform/sync/context.py @@ -48,7 +48,6 @@ class SyncContext: cursor: SyncCursor collection: schemas.Collection connection: schemas.Connection - source_connection_id: Optional[UUID] = None entity_map: dict[type[BaseEntity], UUID] ctx: ApiContext guard_rail: GuardRailService @@ -80,7 +79,6 @@ def __init__( ctx: ApiContext, guard_rail: GuardRailService, logger: ContextualLogger, - source_connection_id: Optional[UUID] = None, force_full_sync: bool = False, # Micro-batching controls should_batch: bool = True, @@ -99,7 +97,6 @@ def __init__( self.cursor = cursor self.collection = collection self.connection = connection - self.source_connection_id = source_connection_id self.entity_map = entity_map self.ctx = ctx self.guard_rail = guard_rail diff --git a/backend/airweave/platform/sync/factory/_context.py b/backend/airweave/platform/sync/factory/_context.py index 34b26d5ec..26a606571 100644 --- 
a/backend/airweave/platform/sync/factory/_context.py +++ b/backend/airweave/platform/sync/factory/_context.py @@ -106,7 +106,6 @@ async def build( sync_job=sync_job, collection=collection, connection=None, - source_connection_id=source_connection_data.get("source_connection_id"), entity_tracker=entity_tracker, state_publisher=state_publisher, cursor=cursor, diff --git a/backend/airweave/platform/sync/raw_data.py b/backend/airweave/platform/sync/raw_data.py index 232587d8a..ab2aa002d 100644 --- a/backend/airweave/platform/sync/raw_data.py +++ b/backend/airweave/platform/sync/raw_data.py @@ -51,7 +51,6 @@ class SyncManifest(BaseModel): sync_id: str source_short_name: str - source_connection_id: Optional[str] = None collection_id: str collection_readable_id: str organization_id: str @@ -451,7 +450,6 @@ async def _update_manifest( manifest = SyncManifest( sync_id=sync_id, source_short_name=self._get_source_short_name(sync_context), - source_connection_id=str(sync_context.source_connection_id) if sync_context.source_connection_id else None, collection_id=str(sync_context.collection.id), collection_readable_id=sync_context.collection.readable_id, organization_id=str(sync_context.collection.organization_id), From edc425c51fa97a9ad879c1317c3e19000d1cb888 Mon Sep 17 00:00:00 2001 From: Siddhesh Deshpande Date: Fri, 2 Jan 2026 17:57:54 +0100 Subject: [PATCH 5/9] Add execution_config_json column to sync_job table Enables database-first approach for sync execution configuration. Config is persisted in DB and refetched by worker to avoid Temporal bloat. 
--- ...e_add_execution_config_json_to_sync_job.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 backend/alembic/versions/9ff3dddc32fe_add_execution_config_json_to_sync_job.py diff --git a/backend/alembic/versions/9ff3dddc32fe_add_execution_config_json_to_sync_job.py b/backend/alembic/versions/9ff3dddc32fe_add_execution_config_json_to_sync_job.py new file mode 100644 index 000000000..8c6821742 --- /dev/null +++ b/backend/alembic/versions/9ff3dddc32fe_add_execution_config_json_to_sync_job.py @@ -0,0 +1,29 @@ +"""add_execution_config_json_to_sync_job + +Revision ID: 9ff3dddc32fe +Revises: h1i2j3k4l5m6 +Create Date: 2026-01-02 17:27:16.572060 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '9ff3dddc32fe' +down_revision = 'h1i2j3k4l5m6' +branch_labels = None +depends_on = None + + +def upgrade(): + """Add execution_config_json column to sync_job table.""" + op.add_column( + "sync_job", + sa.Column("execution_config_json", postgresql.JSONB(), nullable=True), + ) + + +def downgrade(): + """Remove execution_config_json column from sync_job table.""" + op.drop_column("sync_job", "execution_config_json") From fb1a4a61c340533c147f87b32db758dea4e3473d Mon Sep 17 00:00:00 2001 From: Siddhesh Deshpande Date: Fri, 2 Jan 2026 18:34:47 +0100 Subject: [PATCH 6/9] to arf resync --- .../core/source_connection_service.py | 2 +- backend/airweave/core/sync_service.py | 6 ++- backend/airweave/platform/sync/config.py | 5 +++ .../platform/sync/handlers/postgres.py | 2 +- .../platform/sync/multiplex/multiplexer.py | 4 +- .../airweave/platform/sync/orchestrator.py | 7 ++++ .../platform/temporal/activities/sync.py | 2 +- .../add_execution_config_json_to_sync_job.py | 41 ------------------- 8 files changed, 22 insertions(+), 47 deletions(-) delete mode 100644 backend/alembic/versions/add_execution_config_json_to_sync_job.py diff --git 
a/backend/airweave/core/source_connection_service.py b/backend/airweave/core/source_connection_service.py index 8d3d2ce1c..59ce6d4b7 100644 --- a/backend/airweave/core/source_connection_service.py +++ b/backend/airweave/core/source_connection_service.py @@ -1,7 +1,7 @@ """Clean source connection service with auth method inference.""" from datetime import datetime -from typing import TYPE_CHECKING, Any, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from uuid import UUID if TYPE_CHECKING: diff --git a/backend/airweave/core/sync_service.py b/backend/airweave/core/sync_service.py index 1cb4351b4..e3c2e4833 100644 --- a/backend/airweave/core/sync_service.py +++ b/backend/airweave/core/sync_service.py @@ -218,13 +218,17 @@ async def _create_sync_job( execution_config: Optional[Dict[str, Any]] = None, ) -> SyncJob: """Create a sync job record.""" + ctx.logger.info(f"Creating sync job with execution_config: {execution_config}") sync_job_in = schemas.SyncJobCreate( sync_id=sync_id, status=SyncJobStatus.PENDING, execution_config_json=execution_config, ) + ctx.logger.info(f"SyncJobCreate schema: {sync_job_in.model_dump()}") - return await crud.sync_job.create(db, obj_in=sync_job_in, ctx=ctx, uow=uow) + result = await crud.sync_job.create(db, obj_in=sync_job_in, ctx=ctx, uow=uow) + ctx.logger.info(f"Created sync job with execution_config_json: {result.execution_config_json}") + return result async def list_sync_jobs( self, diff --git a/backend/airweave/platform/sync/config.py b/backend/airweave/platform/sync/config.py index c4245600c..6e68ca99d 100644 --- a/backend/airweave/platform/sync/config.py +++ b/backend/airweave/platform/sync/config.py @@ -35,6 +35,9 @@ class SyncExecutionConfig(BaseModel): skip_hash_updates: bool = Field( False, description="Don't update content_hash column" ) + skip_cursor_updates: bool = Field( + False, description="Don't save cursor progress (for ARF-only syncs)" + ) # Performance max_workers: int = 
Field(20, description="Max concurrent workers") @@ -54,7 +57,9 @@ def arf_capture_only(cls) -> "SyncExecutionConfig": """ return cls( enable_vector_handlers=False, + enable_postgres_handler=False, skip_hash_updates=True, + skip_cursor_updates=True, ) @classmethod diff --git a/backend/airweave/platform/sync/handlers/postgres.py b/backend/airweave/platform/sync/handlers/postgres.py index 3060cb5fe..686bb124a 100644 --- a/backend/airweave/platform/sync/handlers/postgres.py +++ b/backend/airweave/platform/sync/handlers/postgres.py @@ -166,7 +166,7 @@ async def _execute_inserts( sync_id=sync_context.sync.id, entity_id=action.entity_id, entity_definition_id=action.entity_definition_id, - hash=None if skip_hashes else action.entity.airweave_system_metadata.hash, + hash="" if skip_hashes else action.entity.airweave_system_metadata.hash, ) ) diff --git a/backend/airweave/platform/sync/multiplex/multiplexer.py b/backend/airweave/platform/sync/multiplex/multiplexer.py index 6b45947b9..e534c01a6 100644 --- a/backend/airweave/platform/sync/multiplex/multiplexer.py +++ b/backend/airweave/platform/sync/multiplex/multiplexer.py @@ -304,7 +304,7 @@ async def resync_from_source( Returns: SyncJob for tracking progress """ - from airweave.core import source_connection_service + from airweave.core.source_connection_service import source_connection_service from airweave.platform.sync.config import SyncExecutionConfig # 1. Validate sync exists @@ -334,11 +334,11 @@ async def resync_from_source( ) # 4. 
Trigger via existing service with ARF-only config + # Note: Don't use force_full_sync for non-continuous sources (always full anyway) job = await source_connection_service.run( self.db, id=source_conn.id, ctx=self.ctx, - force_full_sync=True, execution_config=config.model_dump(), ) diff --git a/backend/airweave/platform/sync/orchestrator.py b/backend/airweave/platform/sync/orchestrator.py index 91d21cb4e..980967a07 100644 --- a/backend/airweave/platform/sync/orchestrator.py +++ b/backend/airweave/platform/sync/orchestrator.py @@ -541,6 +541,13 @@ async def _complete_sync(self) -> None: async def _save_cursor_data(self) -> None: """Save cursor data to database if it exists.""" + # Skip cursor updates if configured (e.g., for ARF-only syncs) + if self.sync_context.execution_config and self.sync_context.execution_config.skip_cursor_updates: + self.sync_context.logger.info( + "⏭️ Skipping cursor update (disabled by execution_config)" + ) + return + if not hasattr(self.sync_context, "cursor") or not self.sync_context.cursor.cursor_data: if self.sync_context.force_full_sync: self.sync_context.logger.info( diff --git a/backend/airweave/platform/temporal/activities/sync.py b/backend/airweave/platform/temporal/activities/sync.py index 9b6a1e7a9..07e48ae40 100644 --- a/backend/airweave/platform/temporal/activities/sync.py +++ b/backend/airweave/platform/temporal/activities/sync.py @@ -31,7 +31,7 @@ async def _run_sync_task( execution_config = None try: async with get_db_context() as db: - sync_job_model = await crud.sync_job.get(db, id=sync_job.id) + sync_job_model = await crud.sync_job.get(db, id=sync_job.id, ctx=ctx) if sync_job_model and sync_job_model.execution_config_json: execution_config = SyncExecutionConfig(**sync_job_model.execution_config_json) ctx.logger.info( diff --git a/backend/alembic/versions/add_execution_config_json_to_sync_job.py b/backend/alembic/versions/add_execution_config_json_to_sync_job.py deleted file mode 100644 index c0c41b81f..000000000 --- 
a/backend/alembic/versions/add_execution_config_json_to_sync_job.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Add execution_config_json column to sync_job for config-driven sync execution. - -Revision ID: n7o8p9q0r1s2 -Revises: h1i2j3k4l5m6 -Create Date: 2025-01-02 12:00:00.000000 - -""" - -from alembic import op -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - - -# revision identifiers, used by Alembic. -revision = "n7o8p9q0r1s2" -down_revision = "h1i2j3k4l5m6" -branch_labels = None -depends_on = None - - -def upgrade() -> None: - """Add execution_config_json column to sync_job table. - - Enables config-driven sync execution without bloating Temporal contract. - Worker refetches SyncJob from DB to get execution config. - - Config controls: - - Destination filtering (target/exclude) - - Handler toggles (vector/ARF/postgres) - - Behavior flags (skip hash comparison/updates) - """ - op.add_column( - "sync_job", - sa.Column("execution_config_json", postgresql.JSONB(), nullable=True), - ) - - -def downgrade() -> None: - """Remove execution_config_json column from sync_job table.""" - op.drop_column("sync_job", "execution_config_json") - From 01fda9fbe3576479a12ca7e5b1eb559d1f248598 Mon Sep 17 00:00:00 2001 From: Siddhesh Deshpande Date: Fri, 2 Jan 2026 18:50:31 +0100 Subject: [PATCH 7/9] arf replay: ruff format --- backend/airweave/core/sync_service.py | 8 +++---- .../platform/sync/actions/resolver.py | 4 +++- backend/airweave/platform/sync/config.py | 9 ++------ .../platform/sync/factory/_destination.py | 23 +++++++++---------- .../airweave/platform/sync/orchestrator.py | 11 +++++---- 5 files changed, 26 insertions(+), 29 deletions(-) diff --git a/backend/airweave/core/sync_service.py b/backend/airweave/core/sync_service.py index e3c2e4833..967f2855a 100644 --- a/backend/airweave/core/sync_service.py +++ b/backend/airweave/core/sync_service.py @@ -199,9 +199,7 @@ async def trigger_sync_run( # Create sync job async with UnitOfWork(db) as uow: - sync_job 
= await self._create_sync_job( - uow.session, sync_id, ctx, uow, execution_config - ) + sync_job = await self._create_sync_job(uow.session, sync_id, ctx, uow, execution_config) await uow.commit() await uow.session.refresh(sync_job) @@ -227,7 +225,9 @@ async def _create_sync_job( ctx.logger.info(f"SyncJobCreate schema: {sync_job_in.model_dump()}") result = await crud.sync_job.create(db, obj_in=sync_job_in, ctx=ctx, uow=uow) - ctx.logger.info(f"Created sync job with execution_config_json: {result.execution_config_json}") + ctx.logger.info( + f"Created sync job with execution_config_json: {result.execution_config_json}" + ) return result async def list_sync_jobs( diff --git a/backend/airweave/platform/sync/actions/resolver.py b/backend/airweave/platform/sync/actions/resolver.py index 8dff37edc..5838bcbb9 100644 --- a/backend/airweave/platform/sync/actions/resolver.py +++ b/backend/airweave/platform/sync/actions/resolver.py @@ -123,7 +123,9 @@ def _force_all_inserts( raise SyncFailureError( f"Entity type {entity.__class__.__name__} not in entity_map" ) - inserts.append(InsertAction(entity=entity, entity_definition_id=entity_definition_id)) + inserts.append( + InsertAction(entity=entity, entity_definition_id=entity_definition_id) + ) return ActionBatch( inserts=inserts, diff --git a/backend/airweave/platform/sync/config.py b/backend/airweave/platform/sync/config.py index 6e68ca99d..55711d02b 100644 --- a/backend/airweave/platform/sync/config.py +++ b/backend/airweave/platform/sync/config.py @@ -17,9 +17,7 @@ class SyncExecutionConfig(BaseModel): target_destinations: Optional[List[UUID]] = Field( None, description="If set, ONLY write to these destinations" ) - exclude_destinations: Optional[List[UUID]] = Field( - None, description="Skip these destinations" - ) + exclude_destinations: Optional[List[UUID]] = Field(None, description="Skip these destinations") destination_strategy: str = Field( "active_and_shadow", description="active_only|shadow_only|all|active_and_shadow", 
@@ -32,9 +30,7 @@ class SyncExecutionConfig(BaseModel): # Behavior flags skip_hash_comparison: bool = Field(False, description="Force INSERT for all entities") - skip_hash_updates: bool = Field( - False, description="Don't update content_hash column" - ) + skip_hash_updates: bool = Field(False, description="Don't update content_hash column") skip_cursor_updates: bool = Field( False, description="Don't save cursor progress (for ARF-only syncs)" ) @@ -74,4 +70,3 @@ def replay_to_destination(cls, destination_id: UUID) -> "SyncExecutionConfig": enable_raw_data_handler=False, skip_hash_comparison=True, ) - diff --git a/backend/airweave/platform/sync/factory/_destination.py b/backend/airweave/platform/sync/factory/_destination.py index 35a0923b0..978ba4279 100644 --- a/backend/airweave/platform/sync/factory/_destination.py +++ b/backend/airweave/platform/sync/factory/_destination.py @@ -51,7 +51,7 @@ async def build( - ACTIVE: receives writes + serves queries - SHADOW: receives writes only (migration testing) - DEPRECATED: skipped (no writes) - + Also respects execution_config for destination filtering. """ destination_ids = await self._get_destination_ids(sync, execution_config) @@ -74,14 +74,14 @@ async def build( ) return destinations - + async def _get_destination_ids( - self, - sync: schemas.Sync, + self, + sync: schemas.Sync, execution_config: Optional[SyncExecutionConfig], ) -> list[UUID]: """Get destination IDs based on roles and execution config. - + Priority: 1. execution_config.target_destinations (if set, use only these) 2. 
execution_config.exclude_destinations (filter out from active/shadow) @@ -93,23 +93,22 @@ async def _get_destination_ids( f"Using target_destinations from config: {execution_config.target_destinations}" ) return execution_config.target_destinations - + # Get base destination IDs (active + shadow by default) destination_ids = await self._get_active_ids(sync) - + # If exclude_destinations is set, filter them out if execution_config and execution_config.exclude_destinations: original_count = len(destination_ids) destination_ids = [ - dest_id for dest_id in destination_ids + dest_id + for dest_id in destination_ids if dest_id not in execution_config.exclude_destinations ] excluded_count = original_count - len(destination_ids) if excluded_count > 0: - self.logger.info( - f"Excluded {excluded_count} destination(s) based on config" - ) - + self.logger.info(f"Excluded {excluded_count} destination(s) based on config") + return destination_ids async def build_for_ids( diff --git a/backend/airweave/platform/sync/orchestrator.py b/backend/airweave/platform/sync/orchestrator.py index 980967a07..9fe039355 100644 --- a/backend/airweave/platform/sync/orchestrator.py +++ b/backend/airweave/platform/sync/orchestrator.py @@ -542,12 +542,13 @@ async def _complete_sync(self) -> None: async def _save_cursor_data(self) -> None: """Save cursor data to database if it exists.""" # Skip cursor updates if configured (e.g., for ARF-only syncs) - if self.sync_context.execution_config and self.sync_context.execution_config.skip_cursor_updates: - self.sync_context.logger.info( - "⏭️ Skipping cursor update (disabled by execution_config)" - ) + if ( + self.sync_context.execution_config + and self.sync_context.execution_config.skip_cursor_updates + ): + self.sync_context.logger.info("⏭️ Skipping cursor update (disabled by execution_config)") return - + if not hasattr(self.sync_context, "cursor") or not self.sync_context.cursor.cursor_data: if self.sync_context.force_full_sync: 
self.sync_context.logger.info( From 7252404b857512b4f7f0aa0169946d761ff16a68 Mon Sep 17 00:00:00 2001 From: Siddhesh Deshpande Date: Sat, 3 Jan 2026 10:08:25 +0100 Subject: [PATCH 8/9] fix(sync): add execution config validation and skip_cursor_load for ARF backfill --- backend/airweave/crud/crud_sync_connection.py | 8 +- backend/airweave/platform/sync/config.py | 78 ++++++++++++++++++- .../platform/sync/factory/_context.py | 6 ++ .../platform/sync/factory/_pipeline.py | 2 + .../platform/sync/multiplex/multiplexer.py | 4 +- 5 files changed, 90 insertions(+), 8 deletions(-) diff --git a/backend/airweave/crud/crud_sync_connection.py b/backend/airweave/crud/crud_sync_connection.py index 067d39328..1c8cbd88d 100644 --- a/backend/airweave/crud/crud_sync_connection.py +++ b/backend/airweave/crud/crud_sync_connection.py @@ -176,7 +176,7 @@ async def create( db.add(db_obj) if uow: - await uow.flush() + pass else: await db.commit() await db.refresh(db_obj) @@ -205,7 +205,7 @@ async def update_role( await db.execute(update(self.model).where(self.model.id == id).values(role=role.value)) if uow: - await uow.flush() + pass else: await db.commit() @@ -244,7 +244,7 @@ async def bulk_update_role( ) if uow: - await uow.flush() + pass else: await db.commit() @@ -274,7 +274,7 @@ async def remove( await db.delete(db_obj) if uow: - await uow.flush() + pass else: await db.commit() diff --git a/backend/airweave/platform/sync/config.py b/backend/airweave/platform/sync/config.py index 55711d02b..c46e9ea23 100644 --- a/backend/airweave/platform/sync/config.py +++ b/backend/airweave/platform/sync/config.py @@ -1,9 +1,10 @@ """Sync execution configuration for controlling sync behavior.""" +import warnings from typing import List, Optional from uuid import UUID -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator class SyncExecutionConfig(BaseModel): @@ -31,6 +32,7 @@ class SyncExecutionConfig(BaseModel): # Behavior flags skip_hash_comparison: bool = 
Field(False, description="Force INSERT for all entities") skip_hash_updates: bool = Field(False, description="Don't update content_hash column") + skip_cursor_load: bool = Field(False, description="Don't load cursor (fetch all entities)") skip_cursor_updates: bool = Field( False, description="Don't save cursor progress (for ARF-only syncs)" ) @@ -39,6 +41,54 @@ class SyncExecutionConfig(BaseModel): max_workers: int = Field(20, description="Max concurrent workers") batch_size: int = Field(100, description="Entity batch size") + @model_validator(mode="after") + def validate_config_logic(self): + """Validate that config combinations make sense.""" + + # 1. Validate destination_strategy is a known value + valid_strategies = {"active_only", "shadow_only", "all", "active_and_shadow"} + if self.destination_strategy not in valid_strategies: + raise ValueError( + f"destination_strategy must be one of {valid_strategies}, " + f"got '{self.destination_strategy}'" + ) + + # 2. Warn if target_destinations overrides destination_strategy + if self.target_destinations and self.destination_strategy != "active_and_shadow": + warnings.warn( + f"destination_strategy='{self.destination_strategy}' is ignored when " + f"target_destinations is set. Explicitly listing destinations takes precedence.", + stacklevel=2, + ) + + # 3. Detect conflicts between target and exclude destinations + if self.target_destinations and self.exclude_destinations: + overlap = set(self.target_destinations) & set(self.exclude_destinations) + if overlap: + raise ValueError( + f"Cannot have same destination in both target_destinations and " + f"exclude_destinations: {overlap}" + ) + + # 4. Warn about replay configs that re-write to ARF + if self.target_destinations and self.enable_raw_data_handler: + warnings.warn( + "Replay to specific destination typically disables raw_data_handler. " + "You're writing the same data to ARF again. Is this intended?", + stacklevel=2, + ) + + # 5. 
Warn about unusual cursor skip combination + if self.skip_cursor_updates and not self.skip_hash_updates: + warnings.warn( + "skip_cursor_updates=True but skip_hash_updates=False. " + "This means next sync will use old cursor but compare new hashes. " + "Typically both are skipped together (e.g., arf_capture_only).", + stacklevel=2, + ) + + return self + @classmethod def default(cls) -> "SyncExecutionConfig": """Normal sync to active+shadow destinations.""" @@ -49,12 +99,14 @@ def arf_capture_only(cls) -> "SyncExecutionConfig": """Capture to ARF without vector DBs or hash updates. Used by multiplexer.resync_from_source() to populate ARF - without touching production vector databases. + without touching production vector databases. Fetches all entities + (skips cursor) to ensure complete ARF backfill. """ return cls( enable_vector_handlers=False, enable_postgres_handler=False, skip_hash_updates=True, + skip_cursor_load=True, skip_cursor_updates=True, ) @@ -70,3 +122,25 @@ def replay_to_destination(cls, destination_id: UUID) -> "SyncExecutionConfig": enable_raw_data_handler=False, skip_hash_comparison=True, ) + + @classmethod + def dry_run(cls) -> "SyncExecutionConfig": + """Validate source without writing anywhere. + + Use cases: + - Test source credentials + - Validate entity schemas + - Count entities before full sync + - Performance testing + + Note: All handlers disabled, but sync will still execute + entity fetching and transformation logic. 
+ """ + return cls( + enable_vector_handlers=False, + enable_raw_data_handler=False, + enable_postgres_handler=False, + skip_hash_updates=True, + skip_cursor_load=True, + skip_cursor_updates=True, + ) diff --git a/backend/airweave/platform/sync/factory/_context.py b/backend/airweave/platform/sync/factory/_context.py index 26a606571..3a1606a24 100644 --- a/backend/airweave/platform/sync/factory/_context.py +++ b/backend/airweave/platform/sync/factory/_context.py @@ -93,6 +93,7 @@ async def build( sync=sync, source_connection_data=source_connection_data, force_full_sync=force_full_sync, + execution_config=execution_config, ) # 6. Detect keyword index capability @@ -129,6 +130,7 @@ async def _create_cursor( sync: schemas.Sync, source_connection_data: dict, force_full_sync: bool, + execution_config: Optional[SyncExecutionConfig] = None, ) -> SyncCursor: """Create cursor with optional data loading.""" cursor_schema = None @@ -143,6 +145,10 @@ async def _create_cursor( "🔄 FORCE FULL SYNC: Skipping cursor data to ensure all entities are fetched " "for accurate orphaned entity cleanup. Will still track cursor for next sync." ) + elif execution_config and execution_config.skip_cursor_load: + self.logger.info( + "🔄 SKIP CURSOR LOAD: Fetching all entities (execution_config.skip_cursor_load=True)" + ) else: cursor_data = await sync_cursor_service.get_cursor_data( db=self.db, sync_id=sync.id, ctx=self.ctx diff --git a/backend/airweave/platform/sync/factory/_pipeline.py b/backend/airweave/platform/sync/factory/_pipeline.py index 99cd512bf..48e685f5d 100644 --- a/backend/airweave/platform/sync/factory/_pipeline.py +++ b/backend/airweave/platform/sync/factory/_pipeline.py @@ -91,6 +91,8 @@ def _create_handlers( enable_postgres = config is None or config.enable_postgres_handler vector_db_destinations: list[BaseDestination] = [] + # TODO(fschmetz/orhanrauf): Self-processing destinations for Vespa - destinations that handle + # their own chunking/embedding. 
Handler implementation coming soon. self_processing_destinations: list[BaseDestination] = [] for dest in destinations: diff --git a/backend/airweave/platform/sync/multiplex/multiplexer.py b/backend/airweave/platform/sync/multiplex/multiplexer.py index e534c01a6..2e54c15c0 100644 --- a/backend/airweave/platform/sync/multiplex/multiplexer.py +++ b/backend/airweave/platform/sync/multiplex/multiplexer.py @@ -321,7 +321,8 @@ async def resync_from_source( status_code=404, detail=f"No source connection found for sync {sync_id}" ) - # 3. Create ARF-only execution config + # 3. Create ARF-only execution config with cursor skip + # skip_cursor_load ensures we fetch ALL data (for ARF backfill) regardless of source type config = SyncExecutionConfig.arf_capture_only() self.logger.info( @@ -334,7 +335,6 @@ async def resync_from_source( ) # 4. Trigger via existing service with ARF-only config - # Note: Don't use force_full_sync for non-continuous sources (always full anyway) job = await source_connection_service.run( self.db, id=source_conn.id, From 03b37661b75b5a7dc19b03140776b7751415f616 Mon Sep 17 00:00:00 2001 From: Siddhesh Deshpande Date: Sat, 3 Jan 2026 10:36:35 +0100 Subject: [PATCH 9/9] arf testing --- .../crud/test_crud_sync_connection_uow.py | 234 ++++++++++++++++++ .../unit/platform/sync/test_cursor_loading.py | 204 +++++++++++++++ .../platform/sync/test_execution_config.py | 204 +++++++++++++++ 3 files changed, 642 insertions(+) create mode 100644 backend/tests/unit/crud/test_crud_sync_connection_uow.py create mode 100644 backend/tests/unit/platform/sync/test_cursor_loading.py create mode 100644 backend/tests/unit/platform/sync/test_execution_config.py diff --git a/backend/tests/unit/crud/test_crud_sync_connection_uow.py b/backend/tests/unit/crud/test_crud_sync_connection_uow.py new file mode 100644 index 000000000..5d7491087 --- /dev/null +++ b/backend/tests/unit/crud/test_crud_sync_connection_uow.py @@ -0,0 +1,234 @@ +"""Tests for CRUD sync_connection UnitOfWork 
class TestCRUDSyncConnectionUnitOfWork:
    """Verify CRUD sync_connection commit semantics with and without a UnitOfWork."""

    @pytest.fixture
    def crud(self):
        """Fresh CRUD instance under test."""
        return CRUDSyncConnection()

    @pytest.fixture
    def mock_db(self):
        """Async session double exposing the methods the CRUD layer touches."""
        db = AsyncMock()
        db.commit = AsyncMock()
        db.refresh = AsyncMock()
        db.execute = AsyncMock()
        db.add = MagicMock()
        db.delete = AsyncMock()
        return db

    @pytest.fixture
    def mock_uow(self):
        """UnitOfWork double, spec'd so only real UoW attributes exist."""
        uow = MagicMock(spec=UnitOfWork)
        # Regression guard: UnitOfWork deliberately has no flush() method.
        assert not hasattr(uow, "flush")
        return uow

    @pytest.mark.asyncio
    async def test_create_without_uow_commits(self, crud, mock_db):
        """create() with no UoW must commit and refresh immediately."""
        await crud.create(
            db=mock_db,
            sync_id=uuid4(),
            connection_id=uuid4(),
            role=DestinationRole.ACTIVE,
            uow=None,
        )

        mock_db.commit.assert_called_once()
        mock_db.refresh.assert_called_once()

    @pytest.mark.asyncio
    async def test_create_with_uow_no_commit(self, crud, mock_db, mock_uow):
        """create() inside a UoW must defer commit/refresh to the UoW __aexit__."""
        await crud.create(
            db=mock_db,
            sync_id=uuid4(),
            connection_id=uuid4(),
            role=DestinationRole.ACTIVE,
            uow=mock_uow,
        )

        mock_db.commit.assert_not_called()
        mock_db.refresh.assert_not_called()

    @pytest.mark.asyncio
    async def test_update_role_without_uow_commits(self, crud, mock_db):
        """update_role() with no UoW must commit immediately."""
        # Simplification: the post-update re-fetch is irrelevant here.
        crud.get = AsyncMock(return_value=None)

        await crud.update_role(
            db=mock_db,
            id=uuid4(),
            role=DestinationRole.SHADOW,
            uow=None,
        )

        mock_db.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_update_role_with_uow_no_commit(self, crud, mock_db, mock_uow):
        """update_role() inside a UoW must not commit."""
        crud.get = AsyncMock(return_value=None)

        await crud.update_role(
            db=mock_db,
            id=uuid4(),
            role=DestinationRole.SHADOW,
            uow=mock_uow,
        )

        mock_db.commit.assert_not_called()

    @pytest.mark.asyncio
    async def test_bulk_update_role_without_uow_commits(self, crud, mock_db):
        """bulk_update_role() with no UoW must commit and report the rowcount."""
        mock_db.execute.return_value = MagicMock(rowcount=5)

        rowcount = await crud.bulk_update_role(
            db=mock_db,
            sync_id=uuid4(),
            from_role=DestinationRole.ACTIVE,
            to_role=DestinationRole.DEPRECATED,
            uow=None,
        )

        mock_db.commit.assert_called_once()
        assert rowcount == 5

    @pytest.mark.asyncio
    async def test_bulk_update_role_with_uow_no_commit(self, crud, mock_db, mock_uow):
        """bulk_update_role() inside a UoW must not commit but still report rowcount."""
        mock_db.execute.return_value = MagicMock(rowcount=5)

        rowcount = await crud.bulk_update_role(
            db=mock_db,
            sync_id=uuid4(),
            from_role=DestinationRole.ACTIVE,
            to_role=DestinationRole.DEPRECATED,
            uow=mock_uow,
        )

        mock_db.commit.assert_not_called()
        assert rowcount == 5

    @pytest.mark.asyncio
    async def test_remove_without_uow_commits(self, crud, mock_db):
        """remove() with no UoW must commit immediately and return True."""
        crud.get = AsyncMock(return_value=MagicMock())

        result = await crud.remove(db=mock_db, id=uuid4(), uow=None)

        mock_db.commit.assert_called_once()
        assert result is True

    @pytest.mark.asyncio
    async def test_remove_with_uow_no_commit(self, crud, mock_db, mock_uow):
        """remove() inside a UoW must not commit but still return True."""
        crud.get = AsyncMock(return_value=MagicMock())

        result = await crud.remove(db=mock_db, id=uuid4(), uow=mock_uow)

        mock_db.commit.assert_not_called()
        assert result is True

    @pytest.mark.asyncio
    async def test_uow_does_not_have_flush_method(self, mock_uow):
        """Regression test: UnitOfWork has no flush() (code once called uow.flush())."""
        assert not hasattr(mock_uow, "flush")

        # A spec'd mock raises on access to the non-existent attribute.
        with pytest.raises(AttributeError):
            mock_uow.flush()


class TestCRUDSyncConnectionTransactionBehavior:
    """Integration-style checks of commit counts across several operations."""

    @pytest.fixture
    def crud(self):
        return CRUDSyncConnection()

    @pytest.mark.asyncio
    async def test_multiple_operations_with_uow_single_transaction(self, crud):
        """Several creates sharing one UoW must produce zero CRUD-level commits."""
        mock_db = AsyncMock()
        mock_uow = MagicMock(spec=UnitOfWork)
        sync_id = uuid4()

        await crud.create(mock_db, sync_id=sync_id, connection_id=uuid4(), uow=mock_uow)
        await crud.create(mock_db, sync_id=sync_id, connection_id=uuid4(), uow=mock_uow)

        # The UoW commits once on __aexit__; the CRUD layer must stay silent.
        assert mock_db.commit.call_count == 0

    @pytest.mark.asyncio
    async def test_operations_without_uow_multiple_commits(self, crud):
        """Without a UoW, every create commits its own transaction."""
        mock_db = AsyncMock()
        sync_id = uuid4()

        await crud.create(mock_db, sync_id=sync_id, connection_id=uuid4(), uow=None)
        await crud.create(mock_db, sync_id=sync_id, connection_id=uuid4(), uow=None)

        assert mock_db.commit.call_count == 2
class TestCursorLoadingWithExecutionConfig:
    """Test that ContextBuilder._create_cursor respects execution-config flags.

    Covers the interaction between ``skip_cursor_load``, ``force_full_sync``
    and the ``arf_capture_only`` preset: whenever a full fetch is requested,
    the builder must never call ``sync_cursor_service.get_cursor_data``.
    """

    def _make_inputs(self):
        """Build the common (mock sync, source_connection_data) pair."""
        mock_sync = MagicMock(id=uuid4())
        source_connection_data = {
            "source_class": MagicMock(_cursor_class=None),
        }
        return mock_sync, source_connection_data

    def _make_builder(self):
        """Build a ContextBuilder wired to mock db/ctx/logger; return it and the logger."""
        mock_logger = MagicMock()
        builder = ContextBuilder(db=AsyncMock(), ctx=MagicMock(), logger=mock_logger)
        return builder, mock_logger

    @pytest.mark.asyncio
    async def test_skip_cursor_load_true_skips_loading(self):
        """skip_cursor_load=True must prevent cursor data from being loaded."""
        mock_sync, source_connection_data = self._make_inputs()
        builder, mock_logger = self._make_builder()
        execution_config = SyncExecutionConfig(skip_cursor_load=True)

        with patch("airweave.platform.sync.factory._context.sync_cursor_service") as mock_service:
            mock_service.get_cursor_data = AsyncMock()

            cursor = await builder._create_cursor(
                sync=mock_sync,
                source_connection_data=source_connection_data,
                force_full_sync=False,
                execution_config=execution_config,
            )

        # No cursor data loaded, and the service was never consulted.
        assert cursor.cursor_data is None
        mock_service.get_cursor_data.assert_not_called()

        # The skip must be logged so operators can see why a full fetch ran.
        log_calls = [str(call) for call in mock_logger.info.call_args_list]
        assert any("SKIP CURSOR LOAD" in call for call in log_calls)

    @pytest.mark.asyncio
    async def test_skip_cursor_load_false_loads_cursor(self):
        """skip_cursor_load=False must load cursor data normally."""
        mock_sync, source_connection_data = self._make_inputs()
        builder, mock_logger = self._make_builder()
        execution_config = SyncExecutionConfig(skip_cursor_load=False)
        cursor_data = {"last_sync": "2024-01-01"}

        with patch("airweave.platform.sync.factory._context.sync_cursor_service") as mock_service:
            mock_service.get_cursor_data = AsyncMock(return_value=cursor_data)

            cursor = await builder._create_cursor(
                sync=mock_sync,
                source_connection_data=source_connection_data,
                force_full_sync=False,
                execution_config=execution_config,
            )

        # Cursor data is loaded via the service exactly once.
        assert cursor.cursor_data == cursor_data
        mock_service.get_cursor_data.assert_called_once()

        log_calls = [str(call) for call in mock_logger.info.call_args_list]
        assert any("Incremental sync" in call for call in log_calls)

    @pytest.mark.asyncio
    async def test_force_full_sync_overrides_skip_cursor_load(self):
        """force_full_sync=True must take precedence over skip_cursor_load.

        BUG FIX: this test previously passed ``skip_cursor_load=False``, so it
        never exercised the precedence it claims to test. The flag must be True
        here; only then does the FORCE FULL SYNC branch (evaluated first in
        ``_create_cursor``) provably win over the skip-cursor-load branch.
        """
        mock_sync, source_connection_data = self._make_inputs()
        builder, mock_logger = self._make_builder()
        # Deliberately True so the force_full_sync branch has to out-prioritize it.
        execution_config = SyncExecutionConfig(skip_cursor_load=True)

        with patch("airweave.platform.sync.factory._context.sync_cursor_service") as mock_service:
            mock_service.get_cursor_data = AsyncMock()

            cursor = await builder._create_cursor(
                sync=mock_sync,
                source_connection_data=source_connection_data,
                force_full_sync=True,  # This should override
                execution_config=execution_config,
            )

        # Cursor data is not loaded, and the service is never consulted.
        assert cursor.cursor_data is None
        mock_service.get_cursor_data.assert_not_called()

        # The FORCE FULL SYNC branch (not the skip-cursor-load one) must log.
        log_calls = [str(call) for call in mock_logger.info.call_args_list]
        assert any("FORCE FULL SYNC" in call for call in log_calls)
        assert not any("SKIP CURSOR LOAD" in call for call in log_calls)

    @pytest.mark.asyncio
    async def test_no_execution_config_loads_cursor_normally(self):
        """execution_config=None must behave like the default (loads cursor)."""
        mock_sync, source_connection_data = self._make_inputs()
        builder, _ = self._make_builder()
        cursor_data = {"last_sync": "2024-01-01"}

        with patch("airweave.platform.sync.factory._context.sync_cursor_service") as mock_service:
            mock_service.get_cursor_data = AsyncMock(return_value=cursor_data)

            cursor = await builder._create_cursor(
                sync=mock_sync,
                source_connection_data=source_connection_data,
                force_full_sync=False,
                execution_config=None,  # No config
            )

        assert cursor.cursor_data == cursor_data
        mock_service.get_cursor_data.assert_called_once()

    @pytest.mark.asyncio
    async def test_arf_capture_only_preset_skips_cursor(self):
        """The arf_capture_only preset must skip cursor loading end-to-end."""
        mock_sync, source_connection_data = self._make_inputs()
        builder, _ = self._make_builder()
        execution_config = SyncExecutionConfig.arf_capture_only()

        with patch("airweave.platform.sync.factory._context.sync_cursor_service") as mock_service:
            mock_service.get_cursor_data = AsyncMock()

            cursor = await builder._create_cursor(
                sync=mock_sync,
                source_connection_data=source_connection_data,
                force_full_sync=False,
                execution_config=execution_config,
            )

        assert cursor.cursor_data is None
        mock_service.get_cursor_data.assert_not_called()
class TestSyncExecutionConfigValidation:
    """Test SyncExecutionConfig validators."""

    def test_invalid_destination_strategy_raises_error(self):
        """An unknown destination_strategy must raise ValueError."""
        with pytest.raises(ValueError, match="destination_strategy must be one of"):
            SyncExecutionConfig(destination_strategy="invalid_strategy")

    def test_valid_destination_strategies(self):
        """Every documented destination strategy is accepted as-is."""
        for strategy in ("active_only", "shadow_only", "all", "active_and_shadow"):
            config = SyncExecutionConfig(destination_strategy=strategy)
            assert config.destination_strategy == strategy

    def test_target_destinations_overrides_strategy_warning(self):
        """Warn when target_destinations makes destination_strategy a no-op.

        BUG FIX: this config also triggers the unrelated ARF-rewrite warning
        (enable_raw_data_handler defaults to True), so the previous
        ``assert len(w) == 1`` failed with two recorded warnings. Assert on
        the specific override message instead.
        """
        dest_id = uuid4()
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            SyncExecutionConfig(
                target_destinations=[dest_id],
                destination_strategy="shadow_only",  # Will be ignored
            )
        assert any(
            "ignored when target_destinations is set" in str(warning.message) for warning in w
        )

    def test_no_warning_for_default_strategy_with_targets(self):
        """No strategy-override warning when the strategy is the default."""
        dest_id = uuid4()
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            SyncExecutionConfig(
                target_destinations=[dest_id],
                destination_strategy="active_and_shadow",  # Default
            )
        # Should only warn about ARF re-writes, not strategy override
        assert all("destination_strategy" not in str(warning.message) for warning in w)

    def test_destination_conflict_raises_error(self):
        """The same destination in both target and exclude lists is an error."""
        dest_id = uuid4()
        with pytest.raises(ValueError, match="Cannot have same destination"):
            SyncExecutionConfig(
                target_destinations=[dest_id],
                exclude_destinations=[dest_id],
            )

    def test_no_error_for_different_destinations(self):
        """Distinct target and exclude destinations are accepted."""
        dest1 = uuid4()
        dest2 = uuid4()
        config = SyncExecutionConfig(
            target_destinations=[dest1],
            exclude_destinations=[dest2],
        )
        assert dest1 in config.target_destinations
        assert dest2 in config.exclude_destinations

    def test_replay_with_arf_handler_warning(self):
        """Warn when a replay config would write the same data back to ARF."""
        dest_id = uuid4()
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            SyncExecutionConfig(
                target_destinations=[dest_id],
                enable_raw_data_handler=True,  # Usually disabled for replay
            )
        assert any("writing the same data to ARF again" in str(warning.message) for warning in w)

    def test_skip_cursor_updates_without_hash_warning(self):
        """Warn about the unusual cursor-skip / hash-update combination."""
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            SyncExecutionConfig(
                skip_cursor_updates=True,
                skip_hash_updates=False,  # Unusual combo
            )
        assert any(
            "skip_cursor_updates=True but skip_hash_updates=False" in str(warning.message)
            for warning in w
        )


class TestSyncExecutionConfigPresets:
    """Test SyncExecutionConfig preset configurations."""

    def test_default_preset(self):
        """default() must enable all handlers with stock tuning values."""
        config = SyncExecutionConfig.default()

        assert config.destination_strategy == "active_and_shadow"
        assert config.enable_vector_handlers is True
        assert config.enable_raw_data_handler is True
        assert config.enable_postgres_handler is True
        assert config.skip_hash_comparison is False
        assert config.skip_hash_updates is False
        assert config.skip_cursor_load is False
        assert config.skip_cursor_updates is False
        assert config.max_workers == 20
        assert config.batch_size == 100

    def test_arf_capture_only_preset(self):
        """arf_capture_only() must disable non-ARF handlers and skip cursor/hash."""
        config = SyncExecutionConfig.arf_capture_only()

        # Handlers disabled except ARF
        assert config.enable_vector_handlers is False
        assert config.enable_raw_data_handler is True
        assert config.enable_postgres_handler is False

        # Cursor and hash skipped
        assert config.skip_hash_updates is True
        assert config.skip_cursor_load is True
        assert config.skip_cursor_updates is True

    def test_replay_to_destination_preset(self):
        """replay_to_destination() must target one destination and force inserts."""
        dest_id = uuid4()
        config = SyncExecutionConfig.replay_to_destination(dest_id)

        # Targets the given destination only
        assert config.target_destinations == [dest_id]

        # ARF handler disabled (reading from ARF, not writing)
        assert config.enable_raw_data_handler is False

        # Force inserts (no hash comparison)
        assert config.skip_hash_comparison is True

        # Vector and postgres handlers remain enabled
        assert config.enable_vector_handlers is True
        assert config.enable_postgres_handler is True

    def test_dry_run_preset(self):
        """dry_run() must disable every handler and skip hash/cursor writes."""
        config = SyncExecutionConfig.dry_run()

        # All handlers disabled
        assert config.enable_vector_handlers is False
        assert config.enable_raw_data_handler is False
        assert config.enable_postgres_handler is False

        # Skips hash and cursor
        assert config.skip_hash_updates is True
        assert config.skip_cursor_load is True
        assert config.skip_cursor_updates is True


class TestSyncExecutionConfigSerialization:
    """Test SyncExecutionConfig serialization for database storage."""

    def test_config_to_dict(self):
        """model_dump() must yield a plain dict preserving the set values."""
        dest_id = uuid4()
        config = SyncExecutionConfig(
            target_destinations=[dest_id],
            max_workers=50,
        )

        config_dict = config.model_dump()
        assert isinstance(config_dict, dict)
        assert config_dict["max_workers"] == 50
        assert dest_id in config_dict["target_destinations"]

    def test_config_from_dict(self):
        """A JSON-shaped dict (UUIDs as strings) must round back into a config."""
        dest_id = uuid4()
        config_dict = {
            "target_destinations": [str(dest_id)],
            "destination_strategy": "active_and_shadow",
            "enable_vector_handlers": False,
            "enable_raw_data_handler": True,
            "enable_postgres_handler": True,
            "skip_hash_comparison": False,
            "skip_hash_updates": True,
            "skip_cursor_load": True,
            "skip_cursor_updates": True,
            "max_workers": 30,
            "batch_size": 200,
        }

        config = SyncExecutionConfig(**config_dict)
        assert config.max_workers == 30
        assert config.batch_size == 200
        assert config.skip_cursor_load is True

    def test_preset_roundtrip(self):
        """A preset must survive a dump/load round trip unchanged."""
        original = SyncExecutionConfig.arf_capture_only()
        restored = SyncExecutionConfig(**original.model_dump())

        assert original.enable_vector_handlers == restored.enable_vector_handlers
        assert original.skip_cursor_load == restored.skip_cursor_load
        assert original.skip_cursor_updates == restored.skip_cursor_updates