diff --git a/backend/infrahub/core/attribute.py b/backend/infrahub/core/attribute.py
index 369350612f..9cffeba66a 100644
--- a/backend/infrahub/core/attribute.py
+++ b/backend/infrahub/core/attribute.py
@@ -578,7 +578,12 @@ def _filter_sensitive(self, value: str, filter_sensitive: bool) -> str:
 
         return value
 
-    async def from_graphql(self, data: dict, db: InfrahubDatabase) -> bool:
+    async def from_graphql(
+        self,
+        data: dict,
+        db: InfrahubDatabase,
+        process_pools: bool = True,
+    ) -> bool:
         """Update attr from GraphQL payload"""
 
         changed = False
@@ -592,7 +597,8 @@ async def from_graphql(self, data: dict, db: InfrahubDatabase) -> bool:
             changed = True
         elif "from_pool" in data:
             self.from_pool = data["from_pool"]
-            await self.node.handle_pool(db=db, attribute=self, errors=[])
+            if process_pools:
+                await self.node.handle_pool(db=db, attribute=self, errors=[])
             changed = True
 
         if changed and self.is_from_profile:
diff --git a/backend/infrahub/core/node/__init__.py b/backend/infrahub/core/node/__init__.py
index a173f5e9bd..4645cf1af8 100644
--- a/backend/infrahub/core/node/__init__.py
+++ b/backend/infrahub/core/node/__init__.py
@@ -257,7 +257,13 @@ async def init(
 
         return cls(**attrs)
 
-    async def handle_pool(self, db: InfrahubDatabase, attribute: BaseAttribute, errors: list) -> None:
+    async def handle_pool(
+        self,
+        db: InfrahubDatabase,
+        attribute: BaseAttribute,
+        errors: list,
+        allocate_resources: bool = True,
+    ) -> None:
         """Evaluate if a resource has been requested from a pool and apply the resource
 
         This method only works on number pools, currently Integer is the only type that has the from_pool
@@ -268,7 +274,7 @@ async def handle_pool(self, db: InfrahubDatabase, attribute: BaseAttribute, erro
             attribute.from_pool = {"id": attribute.schema.parameters.number_pool_id}
             attribute.is_default = False
 
-        if not attribute.from_pool:
+        if not attribute.from_pool or not allocate_resources:
             return
 
         try:
@@ -426,7 +432,12 @@ async def handle_object_template(self, fields: dict, db: InfrahubDatabase, error
             elif relationship_peers := await relationship.get_peers(db=db):
                 fields[relationship_name] = [{"id": peer_id} for peer_id in relationship_peers]
 
-    async def _process_fields(self, fields: dict, db: InfrahubDatabase) -> None:
+    async def _process_fields(
+        self,
+        fields: dict,
+        db: InfrahubDatabase,
+        process_pools: bool = True,
+    ) -> None:
         errors = []
 
         if "_source" in fields.keys():
@@ -480,7 +491,7 @@ async def _process_fields(self, fields: dict, db: InfrahubDatabase) -> None:
         # Generate Attribute and Relationship and assign them
         # -------------------------------------------
         errors.extend(await self._process_fields_relationships(fields=fields, db=db))
-        errors.extend(await self._process_fields_attributes(fields=fields, db=db))
+        errors.extend(await self._process_fields_attributes(fields=fields, db=db, process_pools=process_pools))
 
         if errors:
             raise ValidationError(errors)
@@ -517,7 +528,12 @@ async def _process_fields_relationships(self, fields: dict, db: InfrahubDatabase
 
         return errors
 
-    async def _process_fields_attributes(self, fields: dict, db: InfrahubDatabase) -> list[ValidationError]:
+    async def _process_fields_attributes(
+        self,
+        fields: dict,
+        db: InfrahubDatabase,
+        process_pools: bool,
+    ) -> list[ValidationError]:
         errors: list[ValidationError] = []
 
         for attr_schema in self._schema.attributes:
@@ -542,9 +558,15 @@ async def _process_fields_attributes(self, fields: dict, db: InfrahubDatabase) -
                     )
                 if not self._existing:
                     attribute: BaseAttribute = getattr(self, attr_schema.name)
-                    await self.handle_pool(db=db, attribute=attribute, errors=errors)
+                    await self.handle_pool(
+                        db=db,
+                        attribute=attribute,
+                        errors=errors,
+                        allocate_resources=process_pools,
+                    )
 
-                attribute.validate(value=attribute.value, name=attribute.name, schema=attribute.schema)
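+                # Validation is skipped for pool-backed attributes when allocation is
+                # deferred, since the value only exists once the pool assigns one.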
+                if process_pools or attribute.from_pool is None:
+                    attribute.validate(value=attribute.value, name=attribute.name, schema=attribute.schema)
 
             except ValidationError as exc:
                 errors.append(exc)
@@ -672,7 +694,13 @@ async def process_label(self, db: InfrahubDatabase | None = None) -> None:  # no
             self.label.value = " ".join([word.title() for word in self.name.value.split("_")])
             self.label.is_default = False
 
-    async def new(self, db: InfrahubDatabase, id: str | None = None, **kwargs: Any) -> Self:
+    async def new(
+        self,
+        db: InfrahubDatabase,
+        id: str | None = None,
+        process_pools: bool = True,
+        **kwargs: Any,
+    ) -> Self:
         if id and not is_valid_uuid(id):
             raise ValidationError({"id": f"{id} is not a valid UUID"})
         if id:
@@ -682,7 +710,7 @@ async def new(self, db: InfrahubDatabase, id: str | None = None, **kwargs: Any)
 
         self.id = id or str(UUIDT())
 
-        await self._process_fields(db=db, fields=kwargs)
+        await self._process_fields(db=db, fields=kwargs, process_pools=process_pools)
         await self._process_macros(db=db)
 
         return self
@@ -935,15 +963,29 @@ async def to_graphql(
 
         return response
 
-    async def from_graphql(self, data: dict, db: InfrahubDatabase) -> bool:
-        """Update object from a GraphQL payload."""
+    async def from_graphql(
+        self,
+        data: dict,
+        db: InfrahubDatabase,
+        process_pools: bool = True,
+    ) -> bool:
+        """Update object from a GraphQL payload.
+
+        Args:
+            data: GraphQL payload to apply.
+            db: Database connection used for related lookups.
+            process_pools: Whether resource pool allocations should be performed.
+
+        Returns:
+            True if any field changed, otherwise False.
+ """ changed = False for key, value in data.items(): if key in self._attributes and isinstance(value, dict): attribute = getattr(self, key) - changed |= await attribute.from_graphql(data=value, db=db) + changed |= await attribute.from_graphql(data=value, db=db, process_pools=process_pools) if key in self._relationships: rel: RelationshipManager = getattr(self, key) diff --git a/backend/infrahub/core/node/create.py b/backend/infrahub/core/node/create.py index 6cb827ceda..a39b2f9669 100644 --- a/backend/infrahub/core/node/create.py +++ b/backend/infrahub/core/node/create.py @@ -1,14 +1,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Mapping +from typing import TYPE_CHECKING, Any, Mapping, Sequence +from infrahub import lock from infrahub.core import registry from infrahub.core.constants import RelationshipCardinality, RelationshipKind from infrahub.core.constraint.node.runner import NodeConstraintRunner from infrahub.core.manager import NodeManager from infrahub.core.node import Node +from infrahub.core.node.save import run_constraints_and_save from infrahub.core.protocols import CoreObjectTemplate from infrahub.dependencies.registry import get_component_registry +from infrahub.lock import InfrahubMultiLock +from infrahub.lock_getter import get_lock_names_on_object_mutation if TYPE_CHECKING: from infrahub.core.branch import Branch @@ -141,18 +145,23 @@ async def refresh_for_profile_update( async def _do_create_node( - node_class: type[Node], + obj: Node, db: InfrahubDatabase, - data: dict, - schema: NonGenericSchemaTypes, - fields_to_validate: list, branch: Branch, + fields_to_validate: list[str], node_constraint_runner: NodeConstraintRunner, + lock_names: Sequence[str], ) -> Node: - obj = await node_class.init(db=db, schema=schema, branch=branch) - await obj.new(db=db, **data) - await node_constraint_runner.check(node=obj, field_filters=fields_to_validate) - await obj.save(db=db) + await run_constraints_and_save( + node=obj, + node_constraint_runner=node_constraint_runner, + fields_to_validate=fields_to_validate, + fields_to_save=[], + db=db, + branch=branch, + lock_names=lock_names, + manage_lock=False, + ) object_template = await obj.get_object_template(db=db) if object_template: @@ -162,6 +171,7 @@ async def _do_create_node( template=object_template, obj=obj, fields=fields_to_validate, + constraint_runner=node_constraint_runner, ) return obj @@ -175,35 +185,39 @@ async def create_node( """Create a node in the database if constraint checks succeed.""" component_registry = get_component_registry() - node_constraint_runner = await component_registry.get_component( - NodeConstraintRunner, db=db.start_session() if not db.is_transaction else db, branch=branch - ) node_class = Node if schema.kind in registry.node: node_class = registry.node[schema.kind] fields_to_validate = list(data) - if db.is_transaction: - obj = await _do_create_node( - node_class=node_class, - node_constraint_runner=node_constraint_runner, - db=db, - schema=schema, + preview_obj = await node_class.init(db=db, schema=schema, branch=branch) + await preview_obj.new(db=db, process_pools=False, **data) + schema_branch = db.schema.get_schema_branch(name=branch.name) + lock_names = get_lock_names_on_object_mutation(node=preview_obj, schema_branch=schema_branch) + + async def _persist(current_db: InfrahubDatabase) -> Node: + node_constraint_runner = await component_registry.get_component( + NodeConstraintRunner, db=current_db, branch=branch + ) + obj = await node_class.init(db=current_db, 
+    preview_obj = await node_class.init(db=db, schema=schema, branch=branch)
+    await preview_obj.new(db=db, process_pools=False, **data)
+    schema_branch = db.schema.get_schema_branch(name=branch.name)
+    lock_names = get_lock_names_on_object_mutation(node=preview_obj, schema_branch=schema_branch)
+
+    async def _persist(current_db: InfrahubDatabase) -> Node:
+        node_constraint_runner = await component_registry.get_component(
+            NodeConstraintRunner, db=current_db, branch=branch
+        )
+        obj = await node_class.init(db=current_db, schema=schema, branch=branch)
+        await obj.new(db=current_db, **data)
+
+        return await _do_create_node(
+            obj=obj,
+            db=current_db,
             branch=branch,
             fields_to_validate=fields_to_validate,
-            data=data,
+            node_constraint_runner=node_constraint_runner,
+            lock_names=lock_names,
         )
-    else:
-        async with db.start_transaction() as dbt:
-            obj = await _do_create_node(
-                node_class=node_class,
-                node_constraint_runner=node_constraint_runner,
-                db=dbt,
-                schema=schema,
-                branch=branch,
-                fields_to_validate=fields_to_validate,
-                data=data,
-            )
+
+    obj: Node
+    async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
+        if db.is_transaction:
+            obj = await _persist(db)
+        else:
+            async with db.start_transaction() as dbt:
+                obj = await _persist(dbt)
 
     if await get_profile_ids(db=db, obj=obj):
         obj = await refresh_for_profile_update(db=db, branch=branch, schema=schema, obj=obj)
diff --git a/backend/infrahub/core/node/save.py b/backend/infrahub/core/node/save.py
new file mode 100644
index 0000000000..afca488399
--- /dev/null
+++ b/backend/infrahub/core/node/save.py
@@ -0,0 +1,57 @@
+from collections.abc import Sequence
+
+from infrahub import lock
+from infrahub.core.branch import Branch
+from infrahub.core.constraint.node.runner import NodeConstraintRunner
+from infrahub.core.node import Node
+from infrahub.database import InfrahubDatabase
+from infrahub.lock import InfrahubMultiLock
+from infrahub.lock_getter import get_lock_names_on_object_mutation
+
+
+async def run_constraints_and_save(
+    node: Node,
+    node_constraint_runner: NodeConstraintRunner,
+    fields_to_validate: list[str],
+    fields_to_save: list[str],
+    db: InfrahubDatabase,
+    branch: Branch,
+    skip_uniqueness_check: bool = False,
+    lock_names: Sequence[str] | None = None,
+    manage_lock: bool = True,
+) -> None:
+    """Validate a node and persist it, optionally reusing an existing lock context.
+
+    Args:
+        node: The node instance to validate and persist.
+        node_constraint_runner: Runner executing node-level constraints.
+        fields_to_validate: Field names that must be validated.
+        fields_to_save: Field names that must be persisted.
+        db: Database connection or transaction to use for persistence.
+        branch: Branch associated with the mutation.
+        skip_uniqueness_check: Whether to skip uniqueness constraints.
+        lock_names: Precomputed lock identifiers to reuse when ``manage_lock`` is False.
+        manage_lock: Whether this helper should acquire and release locks itself.
+    """
+
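+    # Callers that already hold the relevant locks pass manage_lock=False together with
+    # the lock_names they acquired; otherwise this helper takes the locks itself.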
+ """ + + if not manage_lock and lock_names is None: + raise ValueError("lock_names must be provided when manage_lock is False") + + schema_branch = db.schema.get_schema_branch(name=branch.name) + locks = ( + list(lock_names) + if lock_names is not None + else get_lock_names_on_object_mutation(node=node, schema_branch=schema_branch) + ) + + async def _persist() -> None: + await node_constraint_runner.check( + node=node, field_filters=fields_to_validate, skip_uniqueness_check=skip_uniqueness_check + ) + await node.save(db=db, fields=fields_to_save) + + if manage_lock: + async with InfrahubMultiLock(lock_registry=lock.registry, locks=locks): + await _persist() + else: + await _persist() diff --git a/backend/infrahub/graphql/mutations/ipam.py b/backend/infrahub/graphql/mutations/ipam.py index 0522a0d9ec..55ec84bfae 100644 --- a/backend/infrahub/graphql/mutations/ipam.py +++ b/backend/infrahub/graphql/mutations/ipam.py @@ -1,5 +1,4 @@ import ipaddress -from ipaddress import IPv4Interface from typing import TYPE_CHECKING, Any from graphene import InputObjectType, Mutation @@ -13,12 +12,14 @@ from infrahub.core.ipam.reconciler import IpamReconciler from infrahub.core.manager import NodeManager from infrahub.core.node import Node +from infrahub.core.node.create import get_profile_ids from infrahub.core.schema import NodeSchema from infrahub.database import InfrahubDatabase, retry_db_transaction from infrahub.exceptions import NodeNotFoundError, ValidationError from infrahub.lock import InfrahubMultiLock, build_object_lock_name from infrahub.log import get_logger +from ...lock_getter import get_lock_names_on_object_mutation from .main import DeleteResult, InfrahubMutationMixin, InfrahubMutationOptions from .node_getter.by_default_filter import MutationNodeGetterByDefaultFilter @@ -106,11 +107,8 @@ def __init_subclass_with_meta__( super().__init_subclass_with_meta__(_meta=_meta, **options) @staticmethod - def _get_lock_name(namespace_id: str, branch: Branch) -> str | None: - if not branch.is_default: - # Do not lock on other branches as reconciliation will be performed at least when merging in main branch. 
+        return [build_object_lock_name(InfrahubKind.IPADDRESS + "_" + namespace_id)]
 
     @classmethod
     async def _mutate_create_object_and_reconcile(
@@ -118,7 +116,7 @@ async def _mutate_create_object_and_reconcile(
         data: InputObjectType,
         branch: Branch,
         db: InfrahubDatabase,
-        ip_address: IPv4Interface | ipaddress.IPv6Interface,
+        ip_address: ipaddress.IPv4Interface | ipaddress.IPv6Interface,
         namespace_id: str,
     ) -> Node:
         address = await cls.mutate_create_object(data=data, db=db, branch=branch)
@@ -126,6 +124,7 @@ async def _mutate_create_object_and_reconcile(
         reconciled_address = await reconciler.reconcile(
             ip_value=ip_address, namespace=namespace_id, node_uuid=address.get_id()
         )
+
         return reconciled_address
 
     @classmethod
@@ -142,17 +141,13 @@ async def mutate_create(
         ip_address = ipaddress.ip_interface(data["address"]["value"])
         namespace_id = await validate_namespace(db=db, branch=branch, data=data)
 
-        async with db.start_transaction() as dbt:
-            if lock_name := cls._get_lock_name(namespace_id, branch):
-                async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]):
-                    reconciled_address = await cls._mutate_create_object_and_reconcile(
-                        data=data, branch=branch, db=dbt, ip_address=ip_address, namespace_id=namespace_id
-                    )
-            else:
+        lock_names = cls._get_lock_names(namespace_id)
+        async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
+            async with db.start_transaction() as dbt:
                 reconciled_address = await cls._mutate_create_object_and_reconcile(
                     data=data, branch=branch, db=dbt, ip_address=ip_address, namespace_id=namespace_id
                 )
-            result = await cls.mutate_create_to_graphql(info=info, db=dbt, obj=reconciled_address)
+                result = await cls.mutate_create_to_graphql(info=info, db=dbt, obj=reconciled_address)
 
         return reconciled_address, result
 
@@ -165,8 +160,24 @@ async def _mutate_update_object_and_reconcile(
         db: InfrahubDatabase,
         address: Node,
         namespace_id: str,
+        fields_to_validate: list[str],
+        fields: list[str],
+        previous_profile_ids: set[str],
+        lock_names: list[str],
     ) -> Node:
-        address = await cls.mutate_update_object(db=db, info=info, data=data, branch=branch, obj=address)
+        address = await cls.mutate_update_object(
+            db=db,
+            info=info,
+            data=data,
+            branch=branch,
+            obj=address,
+            fields_to_validate=fields_to_validate,
+            fields=fields,
+            previous_profile_ids=previous_profile_ids,
+            lock_names=lock_names,
+            manage_lock=False,
+            apply_data=False,
+        )
         reconciler = IpamReconciler(db=db, branch=branch)
         ip_address = ipaddress.ip_interface(address.address.value)
         reconciled_address = await reconciler.reconcile(
@@ -198,18 +209,35 @@ async def mutate_update(
         namespace = await address.ip_namespace.get_peer(db)
         namespace_id = await validate_namespace(db=db, branch=branch, data=data, existing_namespace_id=namespace.id)
 
-        async with db.start_transaction() as dbt:
-            if lock_name := cls._get_lock_name(namespace_id, branch):
-                async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]):
+        before_mutate_profile_ids = await get_profile_ids(db=db, obj=address)
+        await address.from_graphql(db=db, data=data)
+        fields_to_validate = list(data)
+        fields = list(data.keys())
+
+        for field_to_remove in ("id", "hfid"):
+            if field_to_remove in fields:
+                fields.remove(field_to_remove)
+
+        schema_branch = db.schema.get_schema_branch(name=branch.name)
+        lock_names = get_lock_names_on_object_mutation(node=address, schema_branch=schema_branch)
+
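+        # Acquire the namespace-wide lock first, then the object-level locks computed
+        # for this address, before opening the transaction.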
+        namespace_lock_names = cls._get_lock_names(namespace_id)
+        async with InfrahubMultiLock(lock_registry=lock.registry, locks=namespace_lock_names):
+            async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
+                async with db.start_transaction() as dbt:
                     reconciled_address = await cls._mutate_update_object_and_reconcile(
-                        info=info, data=data, branch=branch, address=address, namespace_id=namespace_id, db=dbt
+                        info=info,
+                        data=data,
+                        branch=branch,
+                        db=dbt,
+                        address=address,
+                        namespace_id=namespace_id,
+                        fields_to_validate=fields_to_validate,
+                        fields=fields,
+                        previous_profile_ids=before_mutate_profile_ids,
+                        lock_names=lock_names,
                     )
-            else:
-                reconciled_address = await cls._mutate_update_object_and_reconcile(
-                    info=info, data=data, branch=branch, address=address, namespace_id=namespace_id, db=dbt
-                )
-
-            result = await cls.mutate_update_to_graphql(db=dbt, info=info, obj=reconciled_address)
+                    result = await cls.mutate_update_to_graphql(db=dbt, info=info, obj=reconciled_address)
 
         return address, result
 
@@ -261,11 +289,11 @@ def __init_subclass_with_meta__(
         super().__init_subclass_with_meta__(_meta=_meta, **options)
 
     @staticmethod
-    def _get_lock_name(namespace_id: str) -> str | None:
+    def _get_lock_names(namespace_id: str) -> list[str]:
         # IPPrefix has some cardinality-one relationships involved (parent/child/ip_address),
         # so we need to lock on any branch to avoid creating multiple peers for these relationships
         # during concurrent ipam reconciliations.
-        return build_object_lock_name(InfrahubKind.IPPREFIX + "_" + namespace_id)
+        return [build_object_lock_name(InfrahubKind.IPPREFIX + "_" + namespace_id)]
 
     @classmethod
     async def _mutate_create_object_and_reconcile(
@@ -293,14 +321,13 @@ async def mutate_create(
         db = database or graphql_context.db
         namespace_id = await validate_namespace(db=db, branch=branch, data=data)
 
-        async with db.start_transaction() as dbt:
-            lock_name = cls._get_lock_name(namespace_id)
-            async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]):
+        lock_names = cls._get_lock_names(namespace_id)
+        async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
+            async with db.start_transaction() as dbt:
                 reconciled_prefix = await cls._mutate_create_object_and_reconcile(
                     data=data, branch=branch, db=dbt, namespace_id=namespace_id
                 )
-
-            result = await cls.mutate_create_to_graphql(info=info, db=dbt, obj=reconciled_prefix)
+                result = await cls.mutate_create_to_graphql(info=info, db=dbt, obj=reconciled_prefix)
 
         return reconciled_prefix, result
 
@@ -313,8 +340,24 @@ async def _mutate_update_object_and_reconcile(
         db: InfrahubDatabase,
         prefix: Node,
         namespace_id: str,
+        fields_to_validate: list[str],
+        fields: list[str],
+        previous_profile_ids: set[str],
+        lock_names: list[str],
     ) -> Node:
-        prefix = await cls.mutate_update_object(db=db, info=info, data=data, branch=branch, obj=prefix)
+        prefix = await cls.mutate_update_object(
+            db=db,
+            info=info,
+            data=data,
+            branch=branch,
+            obj=prefix,
+            fields_to_validate=fields_to_validate,
+            fields=fields,
+            previous_profile_ids=previous_profile_ids,
+            lock_names=lock_names,
+            manage_lock=False,
+            apply_data=False,
+        )
         return await cls._reconcile_prefix(
             branch=branch, db=db, prefix=prefix, namespace_id=namespace_id, is_delete=False
         )
@@ -343,13 +386,35 @@ async def mutate_update(
         namespace = await prefix.ip_namespace.get_peer(db)
         namespace_id = await validate_namespace(db=db, branch=branch, data=data, existing_namespace_id=namespace.id)
 
-        async with db.start_transaction() as dbt:
-            lock_name = cls._get_lock_name(namespace_id)
-            async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]):
-                reconciled_prefix = await cls._mutate_update_object_and_reconcile(
-                    info=info, data=data, prefix=prefix, db=dbt, namespace_id=namespace_id, branch=branch
-                )
-            result = await cls.mutate_update_to_graphql(db=dbt, info=info, obj=reconciled_prefix)
+        before_mutate_profile_ids = await get_profile_ids(db=db, obj=prefix)
+        await prefix.from_graphql(db=db, data=data)
+        fields_to_validate = list(data)
+        fields = list(data.keys())
+
+        for field_to_remove in ("id", "hfid"):
+            if field_to_remove in fields:
+                fields.remove(field_to_remove)
+
+        schema_branch = db.schema.get_schema_branch(name=branch.name)
+        lock_names = get_lock_names_on_object_mutation(node=prefix, schema_branch=schema_branch)
+
+        namespace_lock_names = cls._get_lock_names(namespace_id)
+        async with InfrahubMultiLock(lock_registry=lock.registry, locks=namespace_lock_names):
+            async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
+                async with db.start_transaction() as dbt:
+                    reconciled_prefix = await cls._mutate_update_object_and_reconcile(
+                        info=info,
+                        data=data,
+                        branch=branch,
+                        db=dbt,
+                        prefix=prefix,
+                        namespace_id=namespace_id,
+                        fields_to_validate=fields_to_validate,
+                        fields=fields,
+                        previous_profile_ids=before_mutate_profile_ids,
+                        lock_names=lock_names,
+                    )
+                    result = await cls.mutate_update_to_graphql(db=dbt, info=info, obj=reconciled_prefix)
 
         return prefix, result
 
@@ -408,9 +473,9 @@ async def mutate_delete(
         namespace_rels = await prefix.ip_namespace.get_relationships(db=db)
         namespace_id = namespace_rels[0].peer_id
 
-        async with graphql_context.db.start_transaction() as dbt:
-            lock_name = cls._get_lock_name(namespace_id)
-            async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]):
+        lock_names = cls._get_lock_names(namespace_id)
+        async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
+            async with graphql_context.db.start_transaction() as dbt:
                 reconciled_prefix = await cls._reconcile_prefix(
                     branch=branch, db=dbt, prefix=prefix, namespace_id=namespace_id, is_delete=True
                 )
diff --git a/backend/infrahub/graphql/mutations/main.py b/backend/infrahub/graphql/mutations/main.py
index e3d7e74f00..4d3720130e 100644
--- a/backend/infrahub/graphql/mutations/main.py
+++ b/backend/infrahub/graphql/mutations/main.py
@@ -1,15 +1,14 @@
 from __future__ import annotations
 
-import hashlib
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Sequence
 
 from graphene import InputObjectType, Mutation
 from graphene.types.mutation import MutationOptions
 from typing_extensions import Self
 
 from infrahub import config, lock
-from infrahub.core.constants import InfrahubKind, MutationAction
+from infrahub.core.constants import MutationAction
 from infrahub.core.constraint.node.runner import NodeConstraintRunner
 from infrahub.core.manager import NodeManager
 from infrahub.core.node.create import (
@@ -28,9 +27,11 @@
 from infrahub.exceptions import HFIDViolatedError, InitializationError, NodeNotFoundError
 from infrahub.graphql.context import apply_external_context
 from infrahub.graphql.field_extractor import extract_graphql_fields
-from infrahub.lock import InfrahubMultiLock, build_object_lock_name
 from infrahub.log import get_log_data, get_logger
 
+from ...core.node.save import run_constraints_and_save
+from ...lock import InfrahubMultiLock
+from ...lock_getter import get_lock_names_on_object_mutation
 from .node_getter.by_default_filter import MutationNodeGetterByDefaultFilter
 
 if TYPE_CHECKING:
@@ -38,7 +39,6 @@
 
     from infrahub.core.branch import Branch
     from infrahub.core.node import Node
-    from infrahub.core.schema.schema_branch import SchemaBranch
     from infrahub.database import InfrahubDatabase
     from infrahub.graphql.types.context import ContextInput
@@ -47,8 +47,6 @@
 
 log = get_logger()
 
-KINDS_CONCURRENT_MUTATIONS_NOT_ALLOWED = [InfrahubKind.GENERICGROUP]
-
 
 @dataclass
 class DeleteResult:
@@ -153,14 +151,6 @@ async def _call_mutate_create_object(
         """
         Wrapper around mutate_create_object to potentially activate locking.
         """
-        schema_branch = db.schema.get_schema_branch(name=branch.name)
-        lock_names = _get_kind_lock_names_on_object_mutation(
-            kind=cls._meta.active_schema.kind, branch=branch, schema_branch=schema_branch, data=data
-        )
-        if lock_names:
-            async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
-                return await cls.mutate_create_object(data=data, db=db, branch=branch, override_data=override_data)
-
         return await cls.mutate_create_object(data=data, db=db, branch=branch, override_data=override_data)
 
     @classmethod
@@ -220,42 +210,50 @@ async def _call_mutate_update(
         """
         Wrapper around mutate_update to potentially activate locking and call it within a database transaction.
         """
+        before_mutate_profile_ids = await get_profile_ids(db=db, obj=obj)
+        fields_to_validate = list(data)
+        fields = list(data.keys())
 
-        schema_branch = db.schema.get_schema_branch(name=branch.name)
-        lock_names = _get_kind_lock_names_on_object_mutation(
-            kind=cls._meta.active_schema.kind, branch=branch, schema_branch=schema_branch, data=data
+        for field_to_remove in ("id", "hfid"):
+            if field_to_remove in fields:
+                fields.remove(field_to_remove)
+
+        # Prepare a clone to compute locks without triggering pool allocations
+        preview_obj = await NodeManager.get_one_by_id_or_default_filter(
+            db=db,
+            kind=obj.get_kind(),
+            id=obj.get_id(),
+            branch=branch,
         )
+        await preview_obj.from_graphql(db=db, data=data, process_pools=False)
 
-        if db.is_transaction:
-            if lock_names:
-                async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
-                    obj = await cls.mutate_update_object(
-                        db=db, info=info, data=data, branch=branch, obj=obj, skip_uniqueness_check=skip_uniqueness_check
-                    )
-            else:
-                obj = await cls.mutate_update_object(
-                    db=db, info=info, data=data, branch=branch, obj=obj, skip_uniqueness_check=skip_uniqueness_check
-                )
-            result = await cls.mutate_update_to_graphql(db=db, info=info, obj=obj)
-            return obj, result
-
-        async with db.start_transaction() as dbt:
-            if lock_names:
-                async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
-                    obj = await cls.mutate_update_object(
-                        db=dbt,
-                        info=info,
-                        data=data,
-                        branch=branch,
-                        obj=obj,
-                        skip_uniqueness_check=skip_uniqueness_check,
-                    )
-            else:
-                obj = await cls.mutate_update_object(
-                    db=dbt, info=info, data=data, branch=branch, obj=obj, skip_uniqueness_check=skip_uniqueness_check
-                )
-            result = await cls.mutate_update_to_graphql(db=dbt, info=info, obj=obj)
-            return obj, result
+        schema_branch = db.schema.get_schema_branch(name=branch.name)
+        lock_names = get_lock_names_on_object_mutation(node=preview_obj, schema_branch=schema_branch)
+
+        async def _mutate(current_db: InfrahubDatabase) -> tuple[Node, Self]:
+            updated_obj = await cls.mutate_update_object(
+                db=current_db,
+                info=info,
+                data=data,
+                branch=branch,
+                obj=obj,
+                skip_uniqueness_check=skip_uniqueness_check,
+                fields_to_validate=fields_to_validate,
+                fields=fields,
+                previous_profile_ids=before_mutate_profile_ids,
+                lock_names=lock_names,
+                manage_lock=False,
+                apply_data=True,
+            )
+            result = await cls.mutate_update_to_graphql(db=current_db, info=info, obj=updated_obj)
+            return updated_obj, result
+
+        async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
+            if db.is_transaction:
+                return await _mutate(db)
+
+            async with db.start_transaction() as dbt:
+                return await _mutate(dbt)
 
     @classmethod
     @retry_db_transaction(name="object_update")
@@ -273,6 +271,7 @@ async def mutate_update(
         obj = node or await NodeManager.find_object(
             db=db, kind=cls._meta.active_schema.kind, id=data.get("id"), hfid=data.get("hfid"), branch=branch
         )
+
         obj, result = await cls._call_mutate_update(info=info, data=data, db=db, branch=branch, obj=obj)
 
         return obj, result
@@ -286,29 +285,70 @@ async def mutate_update_object(
         branch: Branch,
         obj: Node,
         skip_uniqueness_check: bool = False,
+        fields_to_validate: list[str] | None = None,
+        fields: list[str] | None = None,
+        previous_profile_ids: set[str] | None = None,
+        lock_names: Sequence[str] | None = None,
+        manage_lock: bool = True,
+        apply_data: bool = True,
     ) -> Node:
+        """Update an existing node while ensuring constraints and locking semantics.
+
+        Args:
+            db: Database connection or transaction to use.
+            info: GraphQL resolver info (unused).
+            data: GraphQL payload with updated field values.
+            branch: Active branch for the mutation.
+            obj: Node being updated.
+            skip_uniqueness_check: Whether uniqueness constraints should be skipped.
+            fields_to_validate: Optional list of fields requiring validation.
+            fields: Optional list of fields to persist on save.
+            previous_profile_ids: Optional set of profile IDs prior to mutation.
+            lock_names: Optional precomputed lock identifiers.
+            manage_lock: Whether this helper should manage lock acquisition.
+            apply_data: Whether to apply GraphQL data inside this helper.
+
+        Returns:
+            The updated node instance.
+ """ + component_registry = get_component_registry() node_constraint_runner = await component_registry.get_component(NodeConstraintRunner, db=db, branch=branch) - before_mutate_profile_ids = await get_profile_ids(db=db, obj=obj) - await obj.from_graphql(db=db, data=data) - fields_to_validate = list(data) - await node_constraint_runner.check( - node=obj, field_filters=fields_to_validate, skip_uniqueness_check=skip_uniqueness_check - ) + profile_ids_before = previous_profile_ids or await get_profile_ids(db=db, obj=obj) - fields = list(data.keys()) - for field_to_remove in ("id", "hfid"): - if field_to_remove in fields: - fields.remove(field_to_remove) + if apply_data: + await obj.from_graphql(db=db, data=data) - await obj.save(db=db, fields=fields) + validation_fields = list(fields_to_validate) if fields_to_validate is not None else list(data) + fields_to_save = list(fields) if fields is not None else list(data.keys()) + + for field_to_remove in ("id", "hfid"): + if field_to_remove in fields_to_save: + fields_to_save.remove(field_to_remove) + + locks = lock_names + if locks is None: + schema_branch = db.schema.get_schema_branch(name=branch.name) + locks = get_lock_names_on_object_mutation(node=obj, schema_branch=schema_branch) + + await run_constraints_and_save( + node=obj, + node_constraint_runner=node_constraint_runner, + fields_to_validate=validation_fields, + fields_to_save=fields_to_save, + db=db, + skip_uniqueness_check=skip_uniqueness_check, + branch=branch, + lock_names=locks, + manage_lock=manage_lock, + ) obj = await refresh_for_profile_update( db=db, branch=branch, obj=obj, - previous_profile_ids=before_mutate_profile_ids, + previous_profile_ids=profile_ids_before, schema=cls._meta.active_schema, ) return obj @@ -471,90 +511,5 @@ def __init_subclass_with_meta__( super().__init_subclass_with_meta__(_meta=_meta, **options) -def _get_kinds_to_lock_on_object_mutation(kind: str, schema_branch: SchemaBranch) -> list[str]: - """ - Return kinds for which we want to lock during creating / updating an object of a given schema node. - Lock should be performed on schema kind and its generics having a uniqueness_constraint defined. - If a generic uniqueness constraint is the same as the node schema one, - it means node schema overrided this constraint, in which case we only need to lock on the generic. - """ - - node_schema = schema_branch.get(name=kind) - - schema_uc = None - kinds = [] - if node_schema.uniqueness_constraints: - kinds.append(node_schema.kind) - schema_uc = node_schema.uniqueness_constraints - - if node_schema.is_generic_schema: - return kinds - - generics_kinds = node_schema.inherit_from - - node_schema_kind_removed = False - for generic_kind in generics_kinds: - generic_uc = schema_branch.get(name=generic_kind).uniqueness_constraints - if generic_uc: - kinds.append(generic_kind) - if not node_schema_kind_removed and generic_uc == schema_uc: - # Check whether we should remove original schema kind as it simply overrides uniqueness_constraint - # of a generic - kinds.pop(0) - node_schema_kind_removed = True - return kinds - - -def _should_kind_be_locked_on_any_branch(kind: str, schema_branch: SchemaBranch) -> bool: - """ - Check whether kind or any kind generic is in KINDS_TO_LOCK_ON_ANY_BRANCH. 
- """ - - if kind in KINDS_CONCURRENT_MUTATIONS_NOT_ALLOWED: - return True - - node_schema = schema_branch.get(name=kind) - if node_schema.is_generic_schema: - return False - - for generic_kind in node_schema.inherit_from: - if generic_kind in KINDS_CONCURRENT_MUTATIONS_NOT_ALLOWED: - return True - return False - - -def _hash(value: str) -> str: - # Do not use builtin `hash` for lock names as due to randomization results would differ between - # different processes. - return hashlib.sha256(value.encode()).hexdigest() - - -def _get_kind_lock_names_on_object_mutation( - kind: str, branch: Branch, schema_branch: SchemaBranch, data: InputObjectType -) -> list[str]: - """ - Return objects kind for which we want to avoid concurrent mutation (create/update). Except for some specific kinds, - concurrent mutations are only allowed on non-main branch as objects validations will be performed at least when merging in main branch. - """ - - if not branch.is_default and not _should_kind_be_locked_on_any_branch(kind=kind, schema_branch=schema_branch): - return [] - - if kind == InfrahubKind.GRAPHQLQUERYGROUP: - # Lock on name as well to improve performances - try: - name = data.name.value - return [build_object_lock_name(kind + "." + _hash(name))] - except AttributeError: - # We might reach here if we are updating a CoreGraphQLQueryGroup without updating the name, - # in which case we would not need to lock. This is not supposed to happen as current `update` - # logic first fetches the node with its name. - return [] - - lock_kinds = _get_kinds_to_lock_on_object_mutation(kind, schema_branch) - lock_names = [build_object_lock_name(kind) for kind in lock_kinds] - return lock_names - - def _get_data_fields(data: InputObjectType) -> list[str]: return [field for field in data.keys() if field not in ["id", "hfid"]] diff --git a/backend/infrahub/lock.py b/backend/infrahub/lock.py index 5dad8c2224..4bd68c227f 100644 --- a/backend/infrahub/lock.py +++ b/backend/infrahub/lock.py @@ -5,6 +5,7 @@ import uuid from asyncio import Lock as LocalLock from asyncio import sleep +from contextvars import ContextVar from typing import TYPE_CHECKING import redis.asyncio as redis @@ -132,6 +133,7 @@ def __init__( self.lock_type: str = "multi" if self.in_multi else "individual" self.acquire_time: int | None = None self.event = asyncio.Event() + self._recursion_var: ContextVar[int | None] = ContextVar(f"infrahub_lock_recursion_{self.name}", default=None) if not self.connection or (self.use_local is None and name.startswith("local.")): self.use_local = True @@ -155,6 +157,11 @@ async def __aexit__( await self.release() async def acquire(self) -> None: + depth = self._recursion_var.get() + if depth is not None: + self._recursion_var.set(depth + 1) + return + with LOCK_ACQUIRE_TIME_METRICS.labels(self.name, self.lock_type).time(): if not self.use_local: await self.remote.acquire(token=f"{current_timestamp()}::{WORKER_IDENTITY}") @@ -162,14 +169,28 @@ async def acquire(self) -> None: await self.local.acquire() self.acquire_time = time.time_ns() self.event.clear() + self._recursion_var.set(1) async def release(self) -> None: - duration_ns = time.time_ns() - self.acquire_time - LOCK_RESERVE_TIME_METRICS.labels(self.name, self.lock_type).observe(duration_ns / 1000000000) + depth = self._recursion_var.get() + if depth is None: + raise RuntimeError("Lock release attempted without ownership context.") + + if depth > 1: + self._recursion_var.set(depth - 1) + return + + if self.acquire_time is not None: + duration_ns = time.time_ns() - 
+        depth = self._recursion_var.get()
+        if depth is None:
+            raise RuntimeError("Lock release attempted without ownership context.")
+
+        if depth > 1:
+            self._recursion_var.set(depth - 1)
+            return
+
+        if self.acquire_time is not None:
+            duration_ns = time.time_ns() - self.acquire_time
+            LOCK_RESERVE_TIME_METRICS.labels(self.name, self.lock_type).observe(duration_ns / 1000000000)
+            self.acquire_time = None
+
         if not self.use_local:
             await self.remote.release()
         else:
             self.local.release()
+
+        self._recursion_var.set(None)
         self.event.set()
 
     async def locked(self) -> bool:
diff --git a/backend/infrahub/lock_getter.py b/backend/infrahub/lock_getter.py
new file mode 100644
index 0000000000..b8b9bbb3db
--- /dev/null
+++ b/backend/infrahub/lock_getter.py
@@ -0,0 +1,118 @@
+import hashlib
+from typing import TYPE_CHECKING
+
+from infrahub.core.node import Node
+from infrahub.core.schema import GenericSchema
+from infrahub.core.schema.schema_branch import SchemaBranch
+from infrahub.lock import build_object_lock_name
+
+if TYPE_CHECKING:
+    from infrahub.core.relationship import RelationshipManager
+
+
+def _get_kinds_to_lock_on_object_mutation(kind: str, schema_branch: SchemaBranch) -> list[str]:
+    """
+    Return kinds for which we want to lock during creating / updating an object of a given schema node.
+    Lock should be performed on the schema kind and its generics having a uniqueness_constraint defined.
+    If a generic uniqueness constraint is the same as the node schema one,
+    it means the node schema overrode this constraint, in which case we only need to lock on the generic.
+    """
+
+    node_schema = schema_branch.get(name=kind, duplicate=False)
+
+    schema_uc = None
+    kinds = []
+    if node_schema.uniqueness_constraints:
+        kinds.append(node_schema.kind)
+        schema_uc = node_schema.uniqueness_constraints
+
+    if isinstance(node_schema, GenericSchema):
+        return kinds
+
+    generics_kinds = node_schema.inherit_from
+
+    node_schema_kind_removed = False
+    for generic_kind in generics_kinds:
+        generic_uc = schema_branch.get(name=generic_kind, duplicate=False).uniqueness_constraints
+        if generic_uc:
+            kinds.append(generic_kind)
+            if not node_schema_kind_removed and generic_uc == schema_uc:
+                # Check whether we should remove the original schema kind, as it simply overrides
+                # the uniqueness_constraint of a generic
+                kinds.pop(0)
+                node_schema_kind_removed = True
+    return kinds
+
+
+def _hash(value: str) -> str:
+    # Do not use the builtin `hash` for lock names: its results are randomized per process,
+    # so they would differ between workers.
+    return hashlib.sha256(value.encode()).hexdigest()
+
+
+def get_lock_names_on_object_mutation(node: Node, schema_branch: SchemaBranch) -> list[str]:
+    """
+    Return lock names for an object on which we want to avoid concurrent mutation (create/update).
+    Lock names include the kind, some generic kinds, resource pool ids, and the values of attributes referenced by the corresponding uniqueness constraints.
+ """ + + lock_names: set[str] = set() + + # Check if node is using resource manager allocation via attributes + for attr_name in node.get_schema().attribute_names: + attribute = getattr(node, attr_name, None) + if attribute is not None and getattr(attribute, "from_pool", None) and "id" in attribute.from_pool: + lock_names.add(f"resource_pool.{attribute.from_pool['id']}") + + # Check if relationships allocate resources + for rel_name in node._relationships: + rel_manager: RelationshipManager = getattr(node, rel_name) + for rel in rel_manager._relationships: + if rel.from_pool and "id" in rel.from_pool: + lock_names.add(f"resource_pool.{rel.from_pool['id']}") + + lock_kinds = _get_kinds_to_lock_on_object_mutation(node.get_kind(), schema_branch) + for kind in lock_kinds: + schema = schema_branch.get(name=kind, duplicate=False) + ucs = schema.uniqueness_constraints + if ucs is None: + continue + + ucs_lock_names: list[str] = [] + uc_attributes_names = set() + + for uc in ucs: + uc_attributes_values = [] + # Keep only attributes constraints + for field_path in uc: + # Some attributes may exist in different uniqueness constraints, we de-duplicate them + if field_path in uc_attributes_names: + continue + + # Exclude relationships uniqueness constraints + schema_path = schema.parse_schema_path(path=field_path, schema=schema_branch) + if schema_path.related_schema is not None or schema_path.attribute_schema is None: + continue + + uc_attributes_names.add(field_path) + attr = getattr(node, schema_path.attribute_schema.name, None) + if attr is None or attr.value is None: + # `attr.value` being None corresponds to optional unique attribute. + # `attr` being None is not supposed to happen. + value_hashed = _hash("") + else: + value_hashed = _hash(str(attr.value)) + + uc_attributes_values.append(value_hashed) + + if uc_attributes_values: + uc_lock_name = ".".join(uc_attributes_values) + ucs_lock_names.append(uc_lock_name) + + if not ucs_lock_names: + continue + + partial_lock_name = kind + "." 
+ ".".join(ucs_lock_names) + lock_names.add(build_object_lock_name(partial_lock_name)) + + return list(lock_names) diff --git a/backend/tests/unit/core/test_get_kinds_lock.py b/backend/tests/unit/core/test_get_kinds_lock.py index ed28288b7f..14e2a02e52 100644 --- a/backend/tests/unit/core/test_get_kinds_lock.py +++ b/backend/tests/unit/core/test_get_kinds_lock.py @@ -1,15 +1,15 @@ -from unittest.mock import patch +from copy import deepcopy -from infrahub import lock from infrahub.core import registry -from infrahub.core.constants.infrahubkind import GRAPHQLQUERY, GRAPHQLQUERYGROUP from infrahub.core.initialization import create_branch from infrahub.database import InfrahubDatabase -from infrahub.graphql.mutations.main import ( +from infrahub.lock_getter import ( _get_kinds_to_lock_on_object_mutation, _hash, + get_lock_names_on_object_mutation, ) from tests.helpers.test_app import TestInfrahubApp +from tests.node_creation import create_and_save class TestGetKindsLock(TestInfrahubApp): @@ -36,79 +36,54 @@ async def test_get_kinds_lock( schema_branch = registry.schema.get_schema_branch(name=default_branch.name) assert _get_kinds_to_lock_on_object_mutation(kind="BuiltinIPPrefix", schema_branch=schema_branch) == [] - async def test_lock_core_graphql_query_groups( + async def test_lock_other_branch( self, db: InfrahubDatabase, default_branch, - register_core_models_schema, client, + car_person_schema, ): - graphql_query = await client.create( - kind=GRAPHQLQUERY, - name="a_gql_query", - query="""mutation MyMutation { - InfrahubAccountTokenDelete(data: {id: "%s"}) { - ok - } - }""", - ) - await graphql_query.save() - - # Test create - with patch("infrahub.graphql.mutations.main.InfrahubMultiLock") as mock_infrahub_multi_lock: - group = await client.create(kind=GRAPHQLQUERYGROUP, name="a_gql_group", query=graphql_query) - await group.save() - mock_infrahub_multi_lock.assert_called_once_with( - lock_registry=lock.registry, locks=["global.object.CoreGraphQLQueryGroup." + _hash("a_gql_group")] - ) - - # Test upsert the same node - with patch("infrahub.graphql.mutations.main.InfrahubMultiLock") as mock_infrahub_multi_lock: - group = await client.create(kind=GRAPHQLQUERYGROUP, name="a_gql_group", query=graphql_query) - await group.save(allow_upsert=True) - - mock_infrahub_multi_lock.assert_called_with( - lock_registry=lock.registry, locks=["global.object.CoreGraphQLQueryGroup." + _hash("a_gql_group")] - ) - - # Test updating group name - with patch("infrahub.graphql.mutations.main.InfrahubMultiLock") as mock_infrahub_multi_lock: - group.name = "new_group_name" - await group.save() + other_branch = await create_branch(branch_name="other_branch", db=db) + schema_branch = registry.schema.get_schema_branch(name=other_branch.name) - mock_infrahub_multi_lock.assert_called_once_with( - lock_registry=lock.registry, locks=["global.object.CoreGraphQLQueryGroup." + _hash("new_group_name")] - ) + person = await create_and_save(db=db, schema="TestPerson", name="John", branch=other_branch) + assert get_lock_names_on_object_mutation(person, schema_branch=schema_branch) == [ + "global.object.TestPerson." 
+ _hash("John") + ] - # Test updating other field not present in uniqueness constraint - with patch("infrahub.graphql.mutations.main.InfrahubMultiLock") as mock_infrahub_multi_lock: - query = ( - """mutation { - CoreGraphQLQueryGroupUpdate( - data: { - id: "%s" - label: { value: "new_label"} - } - ){ - ok - } - } - """ - % group.id - ) + async def test_lock_names_only_attributes( + self, + db: InfrahubDatabase, + default_branch, + client, + car_person_schema_unregistered, + ): + car_person_schema_unregistered = deepcopy(car_person_schema_unregistered) + car_person_schema_unregistered.nodes[0].uniqueness_constraints = [ + ["name__value", "color__value", "owner__name"] + ] + registry.schema.register_schema(schema=car_person_schema_unregistered, branch=default_branch.name) - result = await client.execute_graphql(query=query) - assert result["CoreGraphQLQueryGroupUpdate"]["ok"] is True + schema_branch = registry.schema.get_schema_branch(name=default_branch.name) + person = await create_and_save(db=db, schema="TestPerson", name="John") + car = await create_and_save(db=db, schema="TestCar", name="mercedes", color="blue", owner=person) + assert get_lock_names_on_object_mutation(car, schema_branch=schema_branch) == [ + "global.object.TestCar." + _hash("mercedes") + "." + _hash("blue") + ] - mock_infrahub_multi_lock.assert_not_called() + async def test_lock_names_optional_empty_attribute( + self, + db: InfrahubDatabase, + default_branch, + client, + car_person_schema_unregistered, + ): + car_person_schema_unregistered = deepcopy(car_person_schema_unregistered) + car_person_schema_unregistered.nodes[1].uniqueness_constraints = [["height__value"]] + registry.schema.register_schema(schema=car_person_schema_unregistered, branch=default_branch.name) - # Test lock onanother branch - other_branch = await create_branch(branch_name="other_branch", db=db) - with patch("infrahub.graphql.mutations.main.InfrahubMultiLock") as mock_infrahub_multi_lock: - group = await client.create( - kind=GRAPHQLQUERYGROUP, name="one_more_group", query=graphql_query, branch=other_branch.name - ) - await group.save() - mock_infrahub_multi_lock.assert_called_once_with( - lock_registry=lock.registry, locks=["global.object.CoreGraphQLQueryGroup." + _hash("one_more_group")] - ) + schema_branch = registry.schema.get_schema_branch(name=default_branch.name) + person = await create_and_save(db=db, schema="TestPerson", name="John") + assert get_lock_names_on_object_mutation(person, schema_branch=schema_branch) == [ + "global.object.TestPerson." + _hash("") + "." 
+ _hash("John") + ] diff --git a/backend/tests/unit/test_lock.py b/backend/tests/unit/test_lock.py index abda514fa4..4519f4d011 100644 --- a/backend/tests/unit/test_lock.py +++ b/backend/tests/unit/test_lock.py @@ -60,3 +60,34 @@ def test_generate_name(): assert generate_name("simple.name", namespace="other") == "other.simple.name" assert generate_name("simple", namespace="other", local=True) == "local.other.simple" assert generate_name("simple", namespace="other", local=False) == "global.other.simple" + + +async def test_reentrant_lock_allows_nested_acquisitions(): + lock.initialize_lock(local_only=True) + + events: list[str] = [] + + async def reentrant_task() -> None: + async with lock.registry.get(name="resource_pool.test"): + events.append("outer acquired") + async with lock.registry.get(name="resource_pool.test"): + events.append("inner acquired") + await sleep(delay=0.1) + events.append("inner released") + await sleep(delay=0.1) + events.append("outer released") + + async def waiting_task() -> None: + await sleep(delay=0.05) + async with lock.registry.get(name="resource_pool.test"): + events.append("waiter acquired") + + await gather(reentrant_task(), waiting_task()) + + assert events == [ + "outer acquired", + "inner acquired", + "inner released", + "outer released", + "waiter acquired", + ]