Skip to content

Commit 16a210f

Browse files
committed
fix(backend): lock before beginning transactions
WIP: still need to fix the uniqueness constraint lock on update. Signed-off-by: Fatih Acar <[email protected]>
1 parent bde76ed commit 16a210f

File tree

5 files changed

+274
-76
lines changed

5 files changed

+274
-76
lines changed

backend/infrahub/core/node/__init__.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,13 @@ async def init(
257257

258258
return cls(**attrs)
259259

260-
async def handle_pool(self, db: InfrahubDatabase, attribute: BaseAttribute, errors: list) -> None:
260+
async def handle_pool(
261+
self,
262+
db: InfrahubDatabase,
263+
attribute: BaseAttribute,
264+
errors: list,
265+
allocate_resources: bool = True,
266+
) -> None:
261267
"""Evaluate if a resource has been requested from a pool and apply the resource
262268
263269
This method only works on number pools, currently Integer is the only type that has the from_pool
@@ -268,7 +274,7 @@ async def handle_pool(self, db: InfrahubDatabase, attribute: BaseAttribute, erro
268274
attribute.from_pool = {"id": attribute.schema.parameters.number_pool_id}
269275
attribute.is_default = False
270276

271-
if not attribute.from_pool:
277+
if not attribute.from_pool or not allocate_resources:
272278
return
273279

274280
try:
@@ -426,7 +432,12 @@ async def handle_object_template(self, fields: dict, db: InfrahubDatabase, error
426432
elif relationship_peers := await relationship.get_peers(db=db):
427433
fields[relationship_name] = [{"id": peer_id} for peer_id in relationship_peers]
428434

429-
async def _process_fields(self, fields: dict, db: InfrahubDatabase) -> None:
435+
async def _process_fields(
436+
self,
437+
fields: dict,
438+
db: InfrahubDatabase,
439+
process_pools: bool = True,
440+
) -> None:
430441
errors = []
431442

432443
if "_source" in fields.keys():
@@ -480,7 +491,7 @@ async def _process_fields(self, fields: dict, db: InfrahubDatabase) -> None:
480491
# Generate Attribute and Relationship and assign them
481492
# -------------------------------------------
482493
errors.extend(await self._process_fields_relationships(fields=fields, db=db))
483-
errors.extend(await self._process_fields_attributes(fields=fields, db=db))
494+
errors.extend(await self._process_fields_attributes(fields=fields, db=db, process_pools=process_pools))
484495

485496
if errors:
486497
raise ValidationError(errors)
@@ -517,7 +528,12 @@ async def _process_fields_relationships(self, fields: dict, db: InfrahubDatabase
517528

518529
return errors
519530

520-
async def _process_fields_attributes(self, fields: dict, db: InfrahubDatabase) -> list[ValidationError]:
531+
async def _process_fields_attributes(
532+
self,
533+
fields: dict,
534+
db: InfrahubDatabase,
535+
process_pools: bool,
536+
) -> list[ValidationError]:
521537
errors: list[ValidationError] = []
522538

523539
for attr_schema in self._schema.attributes:
@@ -542,9 +558,15 @@ async def _process_fields_attributes(self, fields: dict, db: InfrahubDatabase) -
542558
)
543559
if not self._existing:
544560
attribute: BaseAttribute = getattr(self, attr_schema.name)
545-
await self.handle_pool(db=db, attribute=attribute, errors=errors)
561+
await self.handle_pool(
562+
db=db,
563+
attribute=attribute,
564+
errors=errors,
565+
allocate_resources=process_pools,
566+
)
546567

547-
attribute.validate(value=attribute.value, name=attribute.name, schema=attribute.schema)
568+
if process_pools or attribute.schema.kind != "NumberPool" or attribute.from_pool is None:
569+
attribute.validate(value=attribute.value, name=attribute.name, schema=attribute.schema)
548570
except ValidationError as exc:
549571
errors.append(exc)
550572

@@ -672,7 +694,13 @@ async def process_label(self, db: InfrahubDatabase | None = None) -> None: # no
672694
self.label.value = " ".join([word.title() for word in self.name.value.split("_")])
673695
self.label.is_default = False
674696

675-
async def new(self, db: InfrahubDatabase, id: str | None = None, **kwargs: Any) -> Self:
697+
async def new(
698+
self,
699+
db: InfrahubDatabase,
700+
id: str | None = None,
701+
process_pools: bool = True,
702+
**kwargs: Any,
703+
) -> Self:
676704
if id and not is_valid_uuid(id):
677705
raise ValidationError({"id": f"{id} is not a valid UUID"})
678706
if id:
@@ -682,7 +710,7 @@ async def new(self, db: InfrahubDatabase, id: str | None = None, **kwargs: Any)
682710

683711
self.id = id or str(UUIDT())
684712

685-
await self._process_fields(db=db, fields=kwargs)
713+
await self._process_fields(db=db, fields=kwargs, process_pools=process_pools)
686714
await self._process_macros(db=db)
687715

688716
return self

backend/infrahub/core/node/create.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, Mapping
3+
from typing import TYPE_CHECKING, Any, Mapping, Sequence
44

5+
from infrahub import lock
56
from infrahub.core import registry
67
from infrahub.core.constants import RelationshipCardinality, RelationshipKind
78
from infrahub.core.constraint.node.runner import NodeConstraintRunner
@@ -10,6 +11,8 @@
1011
from infrahub.core.node.save import run_constraints_and_save
1112
from infrahub.core.protocols import CoreObjectTemplate
1213
from infrahub.dependencies.registry import get_component_registry
14+
from infrahub.lock import InfrahubMultiLock
15+
from infrahub.lock_getter import get_lock_names_on_object_mutation
1316

1417
if TYPE_CHECKING:
1518
from infrahub.core.branch import Branch
@@ -142,24 +145,22 @@ async def refresh_for_profile_update(
142145

143146

144147
async def _do_create_node(
145-
node_class: type[Node],
148+
obj: Node,
146149
db: InfrahubDatabase,
147-
data: dict,
148-
schema: NonGenericSchemaTypes,
149-
fields_to_validate: list,
150150
branch: Branch,
151+
fields_to_validate: list[str],
151152
node_constraint_runner: NodeConstraintRunner,
153+
lock_names: Sequence[str],
152154
) -> Node:
153-
obj = await node_class.init(db=db, schema=schema, branch=branch)
154-
await obj.new(db=db, **data)
155-
156155
await run_constraints_and_save(
157156
node=obj,
158157
node_constraint_runner=node_constraint_runner,
159158
fields_to_validate=fields_to_validate,
160159
fields_to_save=[],
161160
db=db,
162161
branch=branch,
162+
lock_names=lock_names,
163+
manage_lock=False,
163164
)
164165

165166
object_template = await obj.get_object_template(db=db)
@@ -170,6 +171,7 @@ async def _do_create_node(
170171
template=object_template,
171172
obj=obj,
172173
fields=fields_to_validate,
174+
constraint_runner=node_constraint_runner,
173175
)
174176
return obj
175177

@@ -183,35 +185,39 @@ async def create_node(
183185
"""Create a node in the database if constraint checks succeed."""
184186

185187
component_registry = get_component_registry()
186-
node_constraint_runner = await component_registry.get_component(
187-
NodeConstraintRunner, db=db.start_session() if not db.is_transaction else db, branch=branch
188-
)
189188
node_class = Node
190189
if schema.kind in registry.node:
191190
node_class = registry.node[schema.kind]
192191

193192
fields_to_validate = list(data)
194-
if db.is_transaction:
195-
obj = await _do_create_node(
196-
node_class=node_class,
197-
node_constraint_runner=node_constraint_runner,
198-
db=db,
199-
schema=schema,
193+
preview_obj = await node_class.init(db=db, schema=schema, branch=branch)
194+
await preview_obj.new(db=db, process_pools=False, **data)
195+
schema_branch = db.schema.get_schema_branch(name=branch.name)
196+
lock_names = get_lock_names_on_object_mutation(node=preview_obj, branch=branch, schema_branch=schema_branch)
197+
198+
async def _persist(current_db: InfrahubDatabase) -> Node:
199+
node_constraint_runner = await component_registry.get_component(
200+
NodeConstraintRunner, db=current_db, branch=branch
201+
)
202+
obj = await node_class.init(db=current_db, schema=schema, branch=branch)
203+
await obj.new(db=current_db, **data)
204+
205+
return await _do_create_node(
206+
obj=obj,
207+
db=current_db,
200208
branch=branch,
201209
fields_to_validate=fields_to_validate,
202-
data=data,
210+
node_constraint_runner=node_constraint_runner,
211+
lock_names=lock_names,
203212
)
204-
else:
205-
async with db.start_transaction() as dbt:
206-
obj = await _do_create_node(
207-
node_class=node_class,
208-
node_constraint_runner=node_constraint_runner,
209-
db=dbt,
210-
schema=schema,
211-
branch=branch,
212-
fields_to_validate=fields_to_validate,
213-
data=data,
214-
)
213+
214+
obj: Node
215+
async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
216+
if db.is_transaction:
217+
obj = await _persist(db)
218+
else:
219+
async with db.start_transaction() as dbt:
220+
obj = await _persist(dbt)
215221

216222
if await get_profile_ids(db=db, obj=obj):
217223
obj = await refresh_for_profile_update(db=db, branch=branch, schema=schema, obj=obj)

backend/infrahub/core/node/save.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections.abc import Sequence
2+
13
from infrahub import lock
24
from infrahub.core.branch import Branch
35
from infrahub.core.constraint.node.runner import NodeConstraintRunner
@@ -15,11 +17,41 @@ async def run_constraints_and_save(
1517
db: InfrahubDatabase,
1618
branch: Branch,
1719
skip_uniqueness_check: bool = False,
20+
lock_names: Sequence[str] | None = None,
21+
manage_lock: bool = True,
1822
) -> None:
23+
"""Validate a node and persist it, optionally reusing an existing lock context.
24+
25+
Args:
26+
node: The node instance to validate and persist.
27+
node_constraint_runner: Runner executing node-level constraints.
28+
fields_to_validate: Field names that must be validated.
29+
fields_to_save: Field names that must be persisted.
30+
db: Database connection or transaction to use for persistence.
31+
branch: Branch associated with the mutation.
32+
skip_uniqueness_check: Whether to skip uniqueness constraints.
33+
lock_names: Precomputed lock identifiers to reuse when ``manage_lock`` is False.
34+
manage_lock: Whether this helper should acquire and release locks itself.
35+
"""
36+
37+
if not manage_lock and lock_names is None:
38+
raise ValueError("lock_names must be provided when manage_lock is False")
39+
1940
schema_branch = db.schema.get_schema_branch(name=branch.name)
20-
lock_names = get_lock_names_on_object_mutation(node=node, branch=branch, schema_branch=schema_branch)
21-
async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
41+
locks = (
42+
list(lock_names)
43+
if lock_names is not None
44+
else get_lock_names_on_object_mutation(node=node, branch=branch, schema_branch=schema_branch)
45+
)
46+
47+
async def _persist() -> None:
2248
await node_constraint_runner.check(
2349
node=node, field_filters=fields_to_validate, skip_uniqueness_check=skip_uniqueness_check
2450
)
2551
await node.save(db=db, fields=fields_to_save)
52+
53+
if manage_lock:
54+
async with InfrahubMultiLock(lock_registry=lock.registry, locks=locks):
55+
await _persist()
56+
else:
57+
await _persist()

0 commit comments

Comments
 (0)