diff --git a/inference/core/workflows/execution_engine/v1/compiler/graph_constructor.py b/inference/core/workflows/execution_engine/v1/compiler/graph_constructor.py index 8308a62cbf..ce7aec61ec 100644 --- a/inference/core/workflows/execution_engine/v1/compiler/graph_constructor.py +++ b/inference/core/workflows/execution_engine/v1/compiler/graph_constructor.py @@ -1633,13 +1633,23 @@ def get_input_data_lineage_excluding_auto_batch_casting( for property_name, input_definition in input_data.items(): if property_name in scalar_parameters_to_be_batched: continue - new_lineages_detected_within_property_data = get_lineage_for_input_property( - step_name=step_name, - property_name=property_name, - input_definition=input_definition, - lineage_deduplication_set=lineage_deduplication_set, - ) - lineages.extend(new_lineages_detected_within_property_data) + if input_definition.is_compound_input(): + new_lineages_detected_within_property_data = ( + get_lineage_from_compound_input( + step_name=step_name, + property_name=property_name, + input_definition=input_definition, + lineage_deduplication_set=lineage_deduplication_set, + ) + ) + lineages.extend(new_lineages_detected_within_property_data) + else: + if input_definition.is_batch_oriented(): + lineage = input_definition.data_lineage + lineage_id = identify_lineage(lineage=lineage) + if lineage_id not in lineage_deduplication_set: + lineage_deduplication_set.add(lineage_id) + lineages.append(lineage) if not lineages: return lineages verify_lineages(step_name=step_name, detected_lineages=lineages) @@ -1729,9 +1739,13 @@ def get_lineage_from_compound_input( def verify_lineages(step_name: str, detected_lineages: List[List[str]]) -> None: - lineages_by_length = defaultdict(list) + lineages_by_length = {} for lineage in detected_lineages: - lineages_by_length[len(lineage)].append(lineage) + lineage_len = len(lineage) + if lineage_len not in lineages_by_length: + lineages_by_length[lineage_len] = [lineage] + else: + lineages_by_length[lineage_len].append(lineage) if len(lineages_by_length) > 2: raise StepInputLineageError( public_message=f"Input data provided for step: `{step_name}` comes with lineages at more than two "