1
1
from __future__ import annotations
2
2
3
- from typing import TYPE_CHECKING , Any , Mapping
3
+ from typing import TYPE_CHECKING , Any , Mapping , Sequence
4
4
5
+ from infrahub import lock
5
6
from infrahub .core import registry
6
7
from infrahub .core .constants import RelationshipCardinality , RelationshipKind
7
8
from infrahub .core .constraint .node .runner import NodeConstraintRunner
10
11
from infrahub .core .node .save import run_constraints_and_save
11
12
from infrahub .core .protocols import CoreObjectTemplate
12
13
from infrahub .dependencies .registry import get_component_registry
14
+ from infrahub .lock import InfrahubMultiLock
15
+ from infrahub .lock_getter import get_lock_names_on_object_mutation
13
16
14
17
if TYPE_CHECKING :
15
18
from infrahub .core .branch import Branch
@@ -142,24 +145,22 @@ async def refresh_for_profile_update(
142
145
143
146
144
147
async def _do_create_node (
145
- node_class : type [ Node ] ,
148
+ obj : Node ,
146
149
db : InfrahubDatabase ,
147
- data : dict ,
148
- schema : NonGenericSchemaTypes ,
149
- fields_to_validate : list ,
150
150
branch : Branch ,
151
+ fields_to_validate : list [str ],
151
152
node_constraint_runner : NodeConstraintRunner ,
153
+ lock_names : Sequence [str ],
152
154
) -> Node :
153
- obj = await node_class .init (db = db , schema = schema , branch = branch )
154
- await obj .new (db = db , ** data )
155
-
156
155
await run_constraints_and_save (
157
156
node = obj ,
158
157
node_constraint_runner = node_constraint_runner ,
159
158
fields_to_validate = fields_to_validate ,
160
159
fields_to_save = [],
161
160
db = db ,
162
161
branch = branch ,
162
+ lock_names = lock_names ,
163
+ manage_lock = False ,
163
164
)
164
165
165
166
object_template = await obj .get_object_template (db = db )
@@ -170,6 +171,7 @@ async def _do_create_node(
170
171
template = object_template ,
171
172
obj = obj ,
172
173
fields = fields_to_validate ,
174
+ constraint_runner = node_constraint_runner ,
173
175
)
174
176
return obj
175
177
@@ -183,35 +185,39 @@ async def create_node(
183
185
"""Create a node in the database if constraint checks succeed."""
184
186
185
187
component_registry = get_component_registry ()
186
- node_constraint_runner = await component_registry .get_component (
187
- NodeConstraintRunner , db = db .start_session () if not db .is_transaction else db , branch = branch
188
- )
189
188
node_class = Node
190
189
if schema .kind in registry .node :
191
190
node_class = registry .node [schema .kind ]
192
191
193
192
fields_to_validate = list (data )
194
- if db .is_transaction :
195
- obj = await _do_create_node (
196
- node_class = node_class ,
197
- node_constraint_runner = node_constraint_runner ,
198
- db = db ,
199
- schema = schema ,
193
+ preview_obj = await node_class .init (db = db , schema = schema , branch = branch )
194
+ await preview_obj .new (db = db , process_pools = False , ** data )
195
+ schema_branch = db .schema .get_schema_branch (name = branch .name )
196
+ lock_names = get_lock_names_on_object_mutation (node = preview_obj , branch = branch , schema_branch = schema_branch )
197
+
198
+ async def _persist (current_db : InfrahubDatabase ) -> Node :
199
+ node_constraint_runner = await component_registry .get_component (
200
+ NodeConstraintRunner , db = current_db , branch = branch
201
+ )
202
+ obj = await node_class .init (db = current_db , schema = schema , branch = branch )
203
+ await obj .new (db = current_db , ** data )
204
+
205
+ return await _do_create_node (
206
+ obj = obj ,
207
+ db = current_db ,
200
208
branch = branch ,
201
209
fields_to_validate = fields_to_validate ,
202
- data = data ,
210
+ node_constraint_runner = node_constraint_runner ,
211
+ lock_names = lock_names ,
203
212
)
204
- else :
205
- async with db .start_transaction () as dbt :
206
- obj = await _do_create_node (
207
- node_class = node_class ,
208
- node_constraint_runner = node_constraint_runner ,
209
- db = dbt ,
210
- schema = schema ,
211
- branch = branch ,
212
- fields_to_validate = fields_to_validate ,
213
- data = data ,
214
- )
213
+
214
+ obj : Node
215
+ async with InfrahubMultiLock (lock_registry = lock .registry , locks = lock_names ):
216
+ if db .is_transaction :
217
+ obj = await _persist (db )
218
+ else :
219
+ async with db .start_transaction () as dbt :
220
+ obj = await _persist (dbt )
215
221
216
222
if await get_profile_ids (db = db , obj = obj ):
217
223
obj = await refresh_for_profile_update (db = db , branch = branch , schema = schema , obj = obj )
0 commit comments