Skip to content

Commit b72719c

Browse files
Iterate to make decrease of dimensionality work
1 parent d328c1f commit b72719c

37 files changed

+716
-387
lines changed

inference/core/workflows/execution_engine/introspection/schema_parser.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,16 @@ def parse_block_manifest(
6464
inputs_accepting_batches_and_scalars = set(
6565
manifest_type.get_parameters_accepting_batches_and_scalars()
6666
)
67+
inputs_enforcing_auto_batch_casting = set(
68+
manifest_type.get_parameters_enforcing_auto_batch_casting()
69+
)
6770
return parse_block_manifest_schema(
6871
schema=schema,
6972
inputs_dimensionality_offsets=inputs_dimensionality_offsets,
7073
dimensionality_reference_property=dimensionality_reference_property,
7174
inputs_accepting_batches=inputs_accepting_batches,
7275
inputs_accepting_batches_and_scalars=inputs_accepting_batches_and_scalars,
76+
inputs_enforcing_auto_batch_casting=inputs_enforcing_auto_batch_casting,
7377
)
7478

7579

@@ -79,6 +83,7 @@ def parse_block_manifest_schema(
7983
dimensionality_reference_property: Optional[str],
8084
inputs_accepting_batches: Set[str],
8185
inputs_accepting_batches_and_scalars: Set[str],
86+
inputs_enforcing_auto_batch_casting: Set[str],
8287
) -> BlockManifestMetadata:
8388
primitive_types = retrieve_primitives_from_schema(
8489
schema=schema,
@@ -89,6 +94,7 @@ def parse_block_manifest_schema(
8994
dimensionality_reference_property=dimensionality_reference_property,
9095
inputs_accepting_batches=inputs_accepting_batches,
9196
inputs_accepting_batches_and_scalars=inputs_accepting_batches_and_scalars,
97+
inputs_enforcing_auto_batch_casting=inputs_enforcing_auto_batch_casting,
9298
)
9399
return BlockManifestMetadata(
94100
primitive_types=primitive_types,
@@ -255,6 +261,7 @@ def retrieve_selectors_from_schema(
255261
dimensionality_reference_property: Optional[str],
256262
inputs_accepting_batches: Set[str],
257263
inputs_accepting_batches_and_scalars: Set[str],
264+
inputs_enforcing_auto_batch_casting: Set[str],
258265
) -> Dict[str, SelectorDefinition]:
259266
result = []
260267
for property_name, property_definition in schema[PROPERTIES_KEY].items():
@@ -277,6 +284,7 @@ def retrieve_selectors_from_schema(
277284
is_list_element=True,
278285
inputs_accepting_batches=inputs_accepting_batches,
279286
inputs_accepting_batches_and_scalars=inputs_accepting_batches_and_scalars,
287+
inputs_enforcing_auto_batch_casting=inputs_enforcing_auto_batch_casting,
280288
)
281289
elif property_definition.get(TYPE_KEY) == OBJECT_TYPE and isinstance(
282290
property_definition.get(ADDITIONAL_PROPERTIES_KEY), dict
@@ -290,6 +298,7 @@ def retrieve_selectors_from_schema(
290298
is_dict_element=True,
291299
inputs_accepting_batches=inputs_accepting_batches,
292300
inputs_accepting_batches_and_scalars=inputs_accepting_batches_and_scalars,
301+
inputs_enforcing_auto_batch_casting=inputs_enforcing_auto_batch_casting,
293302
)
294303
else:
295304
selector = retrieve_selectors_from_simple_property(
@@ -300,6 +309,7 @@ def retrieve_selectors_from_schema(
300309
is_dimensionality_reference_property=is_dimensionality_reference_property,
301310
inputs_accepting_batches=inputs_accepting_batches,
302311
inputs_accepting_batches_and_scalars=inputs_accepting_batches_and_scalars,
312+
inputs_enforcing_auto_batch_casting=inputs_enforcing_auto_batch_casting,
303313
)
304314
if selector is not None:
305315
result.append(selector)
@@ -314,6 +324,7 @@ def retrieve_selectors_from_simple_property(
314324
is_dimensionality_reference_property: bool,
315325
inputs_accepting_batches: Set[str],
316326
inputs_accepting_batches_and_scalars: Set[str],
327+
inputs_enforcing_auto_batch_casting: Set[str],
317328
is_list_element: bool = False,
318329
is_dict_element: bool = False,
319330
) -> Optional[SelectorDefinition]:
@@ -325,7 +336,10 @@ def retrieve_selectors_from_simple_property(
325336
if property_name in inputs_accepting_batches_and_scalars:
326337
points_to_batch = {True, False}
327338
else:
328-
points_to_batch = {property_name in inputs_accepting_batches}
339+
points_to_batch = {
340+
property_name in inputs_accepting_batches
341+
or property_name in inputs_enforcing_auto_batch_casting
342+
}
329343
else:
330344
points_to_batch = {declared_points_to_batch}
331345
allowed_references = [
@@ -359,6 +373,7 @@ def retrieve_selectors_from_simple_property(
359373
is_dimensionality_reference_property=is_dimensionality_reference_property,
360374
inputs_accepting_batches=inputs_accepting_batches,
361375
inputs_accepting_batches_and_scalars=inputs_accepting_batches_and_scalars,
376+
inputs_enforcing_auto_batch_casting=inputs_enforcing_auto_batch_casting,
362377
is_list_element=True,
363378
)
364379
if property_defines_union(property_definition=property_definition):
@@ -372,6 +387,7 @@ def retrieve_selectors_from_simple_property(
372387
is_dimensionality_reference_property=is_dimensionality_reference_property,
373388
inputs_accepting_batches=inputs_accepting_batches,
374389
inputs_accepting_batches_and_scalars=inputs_accepting_batches_and_scalars,
390+
inputs_enforcing_auto_batch_casting=inputs_enforcing_auto_batch_casting,
375391
)
376392
return None
377393

@@ -394,6 +410,7 @@ def retrieve_selectors_from_union_definition(
394410
is_dimensionality_reference_property: bool,
395411
inputs_accepting_batches: Set[str],
396412
inputs_accepting_batches_and_scalars: Set[str],
413+
inputs_enforcing_auto_batch_casting: Set[str],
397414
) -> Optional[SelectorDefinition]:
398415
union_types = (
399416
union_definition.get(ANY_OF_KEY, [])
@@ -410,6 +427,7 @@ def retrieve_selectors_from_union_definition(
410427
is_dimensionality_reference_property=is_dimensionality_reference_property,
411428
inputs_accepting_batches=inputs_accepting_batches,
412429
inputs_accepting_batches_and_scalars=inputs_accepting_batches_and_scalars,
430+
inputs_enforcing_auto_batch_casting=inputs_enforcing_auto_batch_casting,
413431
is_list_element=is_list_element,
414432
)
415433
if result is None:

inference/core/workflows/execution_engine/v1/compiler/entities.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,9 @@ class StepNode(ExecutionGraphNode):
230230
child_execution_branches: Dict[str, str] = field(default_factory=dict)
231231
execution_branches_impacting_inputs: Set[str] = field(default_factory=set)
232232
batch_oriented_parameters: Set[str] = field(default_factory=set)
233-
auto_batch_casting_lineage_supports: Dict[str, AutoBatchCastingConfig] = field(default_factory=dict)
233+
auto_batch_casting_lineage_supports: Dict[str, AutoBatchCastingConfig] = field(
234+
default_factory=dict
235+
)
234236
step_execution_dimensionality: int = 0
235237

236238
def controls_flow(self) -> bool:

0 commit comments

Comments
 (0)