Skip to content

Conversation

@codeflash-ai
Copy link
Contributor

@codeflash-ai codeflash-ai bot commented Aug 25, 2025

⚡️ This pull request contains optimizations for PR #1504

If you approve this dependent PR, these changes will be merged into the original PR branch feature/try-to-beat-the-limitation-of-ee-in-terms-of-singular-elements-pushed-into-batch-inputs.

This PR will be automatically closed if the original PR is merged.


📄 37% (0.37x) speedup for construct_simd_step_input in inference/core/workflows/execution_engine/v1/executor/execution_data_manager/step_input_assembler.py

⏱️ Runtime : 1.99 milliseconds 1.46 milliseconds (best of 40 runs)

📝 Explanation and details

The optimized code achieves a 36% speedup through a single but impactful conditional check optimization in the prepare_parameters function.

Key Optimization:
The main performance improvement comes from adding an if empty_indices: check before executing expensive list comprehension and data removal operations:

# Original: Always executes these expensive operations
indices = [e for e in indices if e not in empty_indices]
result = remove_indices(value=result, indices=empty_indices)

# Optimized: Only executes when empty_indices is non-empty
if empty_indices:
    indices = [e for e in indices if e not in empty_indices]
    result = remove_indices(value=result, indices=empty_indices)

Why this optimization works:

  • In many test cases, empty_indices is an empty set, making the filtering operations unnecessary
  • The list comprehension [e for e in indices if e not in empty_indices] has O(n*m) complexity where n=len(indices) and m=len(empty_indices)
  • remove_indices() recursively processes nested data structures, which is expensive even for empty removal sets
  • By avoiding these operations when empty_indices is empty, we eliminate significant computational overhead

Performance impact by test case type:

  • Large batch inputs see the biggest gains (43-107% faster) because they avoid expensive O(n) operations on large datasets when no filtering is needed
  • Basic test cases show consistent 15-25% improvements from avoiding unnecessary operations
  • Edge cases with actual empty elements may see minimal or slightly negative impact (0.5% slower) due to the additional conditional check, but this is negligible compared to the gains in common cases

This optimization is particularly effective because most workflow executions don't have empty batch elements that need filtering, making the conditional check a highly beneficial guard against unnecessary work.

Correctness verification report:

Test Status
⏪ Replay Tests 🔘 None Found
⚙️ Existing Unit Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
🌀 Generated Regression Tests 19 Passed
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Union

# imports
import pytest  # used for our unit tests
from inference.core.workflows.execution_engine.v1.executor.execution_data_manager.step_input_assembler import \
    construct_simd_step_input


class ExecutionEngineRuntimeError(Exception):
    def __init__(self, public_message, context=None):
        super().__init__(public_message)
        self.public_message = public_message
        self.context = context

# For batch indices, we'll use tuples of ints for simplicity
DynamicBatchIndex = Tuple[int, ...]


class Batch:
    def __init__(self, data: List[Any], indices: List[DynamicBatchIndex]):
        self.data = data
        self.indices = indices

    def iter_with_indices(self):
        return zip(self.indices, self.data)

    def remove_by_indices(self, indices_to_remove: Set[DynamicBatchIndex]):
        new_data = []
        new_indices = []
        for idx, dat in zip(self.indices, self.data):
            if idx not in indices_to_remove:
                new_data.append(dat)
                new_indices.append(idx)
        return Batch(new_data, new_indices)

    def __eq__(self, other):
        if not isinstance(other, Batch):
            return False
        return self.data == other.data and self.indices == other.indices

    def __repr__(self):
        return f"Batch(data={self.data}, indices={self.indices})"


@dataclass(frozen=True)
class ParameterSpecification:
    parameter_name: str
    nested_element_key: Optional[str] = None

class NodeInputCategory:
    pass

@dataclass(frozen=True)
class StepInputDefinition:
    parameter_specification: ParameterSpecification
    category: NodeInputCategory
    _batch_oriented: bool = False
    _dimensionality: int = 0
    _points_to_input: bool = True
    _points_to_step_output: bool = False
    _value: Any = None
    _selector: Optional[str] = None
    _data_lineage: Optional[List[int]] = None

    @classmethod
    def is_compound_input(cls) -> bool:
        return False

    def is_batch_oriented(self) -> bool:
        return self._batch_oriented

    def get_dimensionality(self) -> int:
        return self._dimensionality

    def points_to_input(self) -> bool:
        return self._points_to_input

    def points_to_step_output(self) -> bool:
        return self._points_to_step_output

    @property
    def selector(self):
        return self._selector

    @property
    def data_lineage(self):
        return self._data_lineage

    @property
    def value(self):
        return self._value

@dataclass(frozen=True)
class StaticStepInputDefinition(StepInputDefinition):
    pass

@dataclass(frozen=True)
class DynamicStepInputDefinition(StepInputDefinition):
    pass

@dataclass(frozen=True)
class CompoundStepInputDefinition:
    # For this stub, compound inputs are a list or dict of StepInputDefinition
    _definitions: Union[List[StepInputDefinition], Dict[str, StepInputDefinition]]
    _list: bool = False

    @classmethod
    def is_compound_input(cls) -> bool:
        return True

    def represents_list_of_inputs(self) -> bool:
        return self._list

    def iterate_through_definitions(self):
        if self._list:
            return iter(self._definitions)
        else:
            return iter(self._definitions.values())

    def get_dimensionality(self) -> int:
        dims = []
        for d in self.iterate_through_definitions():
            dims.append(d.get_dimensionality())
        dims.append(0)
        return max(dims)

@dataclass
class StepNode:
    name: str
    input_data: Dict[str, Any]
    execution_branches_impacting_inputs: List[str]
    step_execution_dimensionality: int
    step_manifest: Any
    auto_batch_casting_lineage_supports: Dict[str, Any]

class BatchModeSIMDStepInput:
    def __init__(self, indices, parameters):
        self.indices = indices
        self.parameters = parameters

    def __eq__(self, other):
        if not isinstance(other, BatchModeSIMDStepInput):
            return False
        return self.indices == other.indices and self.parameters == other.parameters

    def __repr__(self):
        return f"BatchModeSIMDStepInput(indices={self.indices}, parameters={self.parameters})"

# Dummy "manifest" class
class DummyManifest:
    @classmethod
    def accepts_batch_input(cls):
        return True

    @classmethod
    def accepts_empty_values(cls):
        return False

# Dummy dynamic batches manager
class DynamicBatchesManager:
    def __init__(self, lineage_to_indices: Dict[Tuple[int, ...], List[DynamicBatchIndex]]):
        self.lineage_to_indices = lineage_to_indices

    def get_indices_for_data_lineage(self, lineage):
        return self.lineage_to_indices.get(tuple(lineage), [])

# Dummy execution cache
class ExecutionCache:
    def __init__(self, step_outputs: Dict[str, Any], batch_outputs: Dict[str, List[Any]]):
        self.step_outputs = step_outputs
        self.batch_outputs = batch_outputs

    def get_non_batch_output(self, selector):
        return self.step_outputs.get(selector, None)

    def get_batch_output(self, selector, batch_elements_indices, mask):
        # Return the output for the given indices, optionally filtered by mask
        data = self.batch_outputs.get(selector, [])
        if mask is not None:
            return [
                val if idx in mask else None
                for idx, val in zip(batch_elements_indices, data)
            ]
        else:
            return [val for idx, val in zip(batch_elements_indices, data)]

# Dummy branching manager
class BranchingManager:
    def __init__(self, masks: Dict[str, Union[Set[DynamicBatchIndex], bool]]):
        self._masks = masks
        self._batch_compatibility = {
            branch_name: not isinstance(mask, bool)
            for branch_name, mask in masks.items()
        }

    def get_mask(self, execution_branch: str) -> Union[Set[DynamicBatchIndex], bool]:
        return self._masks[execution_branch]

    def is_execution_branch_batch_oriented(self, execution_branch: str) -> bool:
        return self._batch_compatibility[execution_branch]

    def is_execution_branch_registered(self, execution_branch: str) -> bool:
        return execution_branch in self._masks
from inference.core.workflows.execution_engine.v1.executor.execution_data_manager.step_input_assembler import \
    construct_simd_step_input

# --- UNIT TESTS ---

# 1. BASIC TEST CASES

def test_single_batch_input_basic():
    """
    Test a single batch input with one execution branch and straightforward mask.
    """
    # Setup
    parameter = DynamicStepInputDefinition(
        parameter_specification=ParameterSpecification(parameter_name="input1"),
        category=NodeInputCategory(),
        _batch_oriented=True,
        _dimensionality=1,
        _points_to_input=True,
        _points_to_step_output=False,
        _selector="input1",
        _data_lineage=[0]
    )
    step_node = StepNode(
        name="step1",
        input_data={"input1": parameter},
        execution_branches_impacting_inputs=["branch1"],
        step_execution_dimensionality=1,
        step_manifest=DummyManifest,
        auto_batch_casting_lineage_supports={}
    )
    runtime_parameters = {"input1": [10, 20, 30]}
    masks = {"branch1": {(0,), (1,), (2,)}}
    branching_manager = BranchingManager(masks)
    dynamic_batches_manager = DynamicBatchesManager({(0,): [(0,), (1,), (2,)]})
    execution_cache = ExecutionCache({}, {})

    # Execute
    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 24.0μs -> 20.8μs (15.4% faster)

    # Assert
    expected_batch = Batch([10, 20, 30], [(0,), (1,), (2,)])
    expected = BatchModeSIMDStepInput(indices=[(0,), (1,), (2,)], parameters={"input1": expected_batch})

def test_multiple_inputs_basic():
    """
    Test multiple batch inputs with matching indices and masks.
    """
    param1 = DynamicStepInputDefinition(
        parameter_specification=ParameterSpecification(parameter_name="input1"),
        category=NodeInputCategory(),
        _batch_oriented=True,
        _dimensionality=1,
        _points_to_input=True,
        _selector="input1",
        _data_lineage=[0]
    )
    param2 = DynamicStepInputDefinition(
        parameter_specification=ParameterSpecification(parameter_name="input2"),
        category=NodeInputCategory(),
        _batch_oriented=True,
        _dimensionality=1,
        _points_to_input=True,
        _selector="input2",
        _data_lineage=[0]
    )
    step_node = StepNode(
        name="step2",
        input_data={"input1": param1, "input2": param2},
        execution_branches_impacting_inputs=["branch1"],
        step_execution_dimensionality=1,
        step_manifest=DummyManifest,
        auto_batch_casting_lineage_supports={}
    )
    runtime_parameters = {"input1": [1, 2], "input2": [3, 4]}
    masks = {"branch1": {(0,), (1,)}}
    branching_manager = BranchingManager(masks)
    dynamic_batches_manager = DynamicBatchesManager({(0,): [(0,), (1,)]})
    execution_cache = ExecutionCache({}, {})

    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 28.9μs -> 24.6μs (17.7% faster)

    expected = BatchModeSIMDStepInput(
        indices=[(0,), (1,)],
        parameters={
            "input1": Batch([1, 2], [(0,), (1,)]),
            "input2": Batch([3, 4], [(0,), (1,)]),
        }
    )


def test_empty_input_list():
    """
    Test with an empty input list, should result in empty indices and values.
    """
    param = DynamicStepInputDefinition(
        parameter_specification=ParameterSpecification(parameter_name="input1"),
        category=NodeInputCategory(),
        _batch_oriented=True,
        _dimensionality=1,
        _points_to_input=True,
        _selector="input1",
        _data_lineage=[0]
    )
    step_node = StepNode(
        name="step_empty",
        input_data={"input1": param},
        execution_branches_impacting_inputs=["branch1"],
        step_execution_dimensionality=1,
        step_manifest=DummyManifest,
        auto_batch_casting_lineage_supports={}
    )
    runtime_parameters = {"input1": []}
    masks = {"branch1": set()}
    branching_manager = BranchingManager(masks)
    dynamic_batches_manager = DynamicBatchesManager({(0,): []})
    execution_cache = ExecutionCache({}, {})

    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 21.9μs -> 19.1μs (14.1% faster)
    expected = BatchModeSIMDStepInput(indices=[], parameters={"input1": Batch([], [])})

def test_mask_removes_some_elements():
    """
    Test where the mask removes some elements from the batch.
    """
    param = DynamicStepInputDefinition(
        parameter_specification=ParameterSpecification(parameter_name="input1"),
        category=NodeInputCategory(),
        _batch_oriented=True,
        _dimensionality=1,
        _points_to_input=True,
        _selector="input1",
        _data_lineage=[0]
    )
    step_node = StepNode(
        name="step_mask",
        input_data={"input1": param},
        execution_branches_impacting_inputs=["branch1"],
        step_execution_dimensionality=1,
        step_manifest=DummyManifest,
        auto_batch_casting_lineage_supports={}
    )
    runtime_parameters = {"input1": [5, 6, 7, 8]}
    masks = {"branch1": {(1,), (3,)}}
    branching_manager = BranchingManager(masks)
    dynamic_batches_manager = DynamicBatchesManager({(0,): [(0,), (1,), (2,), (3,)]})
    execution_cache = ExecutionCache({}, {})

    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 23.8μs -> 24.1μs (1.21% slower)
    # Only indices (1,) and (3,) should remain, others set to None and removed
    expected_batch = Batch([None, 6, None, 8], [(0,), (1,), (2,), (3,)]).remove_by_indices({(0,), (2,)})
    expected = BatchModeSIMDStepInput(indices=[(1,), (3,)], parameters={"input1": expected_batch})

def test_compound_input_dict():
    """
    Test a compound input as a dict of two batch parameters.
    """
    param1 = DynamicStepInputDefinition(
        parameter_specification=ParameterSpecification(parameter_name="inputA", nested_element_key="A"),
        category=NodeInputCategory(),
        _batch_oriented=True,
        _dimensionality=1,
        _points_to_input=True,
        _selector="inputA",
        _data_lineage=[0]
    )
    param2 = DynamicStepInputDefinition(
        parameter_specification=ParameterSpecification(parameter_name="inputB", nested_element_key="B"),
        category=NodeInputCategory(),
        _batch_oriented=True,
        _dimensionality=1,
        _points_to_input=True,
        _selector="inputB",
        _data_lineage=[0]
    )
    compound = CompoundStepInputDefinition({"A": param1, "B": param2}, _list=False)
    step_node = StepNode(
        name="step_compound",
        input_data={"compound": compound},
        execution_branches_impacting_inputs=["branch1"],
        step_execution_dimensionality=1,
        step_manifest=DummyManifest,
        auto_batch_casting_lineage_supports={}
    )
    runtime_parameters = {"inputA": [1, 2], "inputB": [10, 20]}
    masks = {"branch1": {(0,), (1,)}}
    branching_manager = BranchingManager(masks)
    dynamic_batches_manager = DynamicBatchesManager({(0,): [(0,), (1,)]})
    execution_cache = ExecutionCache({}, {})

    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 34.2μs -> 28.2μs (21.1% faster)
    expected = BatchModeSIMDStepInput(
        indices=[(0,), (1,)],
        parameters={
            "compound": {
                "A": Batch([1, 2], [(0,), (1,)]),
                "B": Batch([10, 20], [(0,), (1,)]),
            }
        }
    )

def test_compound_input_list():
    """
    Test a compound input as a list of batch parameters.
    """
    param1 = DynamicStepInputDefinition(
        parameter_specification=ParameterSpecification(parameter_name="inputA"),
        category=NodeInputCategory(),
        _batch_oriented=True,
        _dimensionality=1,
        _points_to_input=True,
        _selector="inputA",
        _data_lineage=[0]
    )
    param2 = DynamicStepInputDefinition(
        parameter_specification=ParameterSpecification(parameter_name="inputB"),
        category=NodeInputCategory(),
        _batch_oriented=True,
        _dimensionality=1,
        _points_to_input=True,
        _selector="inputB",
        _data_lineage=[0]
    )
    compound = CompoundStepInputDefinition([param1, param2], _list=True)
    step_node = StepNode(
        name="step_compound_list",
        input_data={"compound": compound},
        execution_branches_impacting_inputs=["branch1"],
        step_execution_dimensionality=1,
        step_manifest=DummyManifest,
        auto_batch_casting_lineage_supports={}
    )
    runtime_parameters = {"inputA": [1, 2], "inputB": [10, 20]}
    masks = {"branch1": {(0,), (1,)}}
    branching_manager = BranchingManager(masks)
    dynamic_batches_manager = DynamicBatchesManager({(0,): [(0,), (1,)]})
    execution_cache = ExecutionCache({}, {})

    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 32.2μs -> 27.2μs (18.4% faster)
    expected = BatchModeSIMDStepInput(
        indices=[(0,), (1,)],
        parameters={
            "compound": [
                Batch([1, 2], [(0,), (1,)]),
                Batch([10, 20], [(0,), (1,)]),
            ]
        }
    )



def test_large_batch_input():
    """
    Test with a large batch input (size 1000).
    """
    N = 1000
    param = DynamicStepInputDefinition(
        parameter_specification=ParameterSpecification(parameter_name="input1"),
        category=NodeInputCategory(),
        _batch_oriented=True,
        _dimensionality=1,
        _points_to_input=True,
        _selector="input1",
        _data_lineage=[0]
    )
    step_node = StepNode(
        name="step_large",
        input_data={"input1": param},
        execution_branches_impacting_inputs=["branch1"],
        step_execution_dimensionality=1,
        step_manifest=DummyManifest,
        auto_batch_casting_lineage_supports={}
    )
    runtime_parameters = {"input1": list(range(N))}
    masks = {"branch1": set((i,) for i in range(N))}
    branching_manager = BranchingManager(masks)
    dynamic_batches_manager = DynamicBatchesManager({(0,): [(i,) for i in range(N)]})
    execution_cache = ExecutionCache({}, {})

    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 346μs -> 241μs (43.6% faster)
    expected = BatchModeSIMDStepInput(
        indices=[(i,) for i in range(N)],
        parameters={"input1": Batch(list(range(N)), [(i,) for i in range(N)])}
    )

def test_large_compound_input_dict():
    """
    Test a large compound input dict with 10 keys, each with a batch of size 100.
    """
    N = 100
    keys = [f"k{i}" for i in range(10)]
    params = {
        k: DynamicStepInputDefinition(
            parameter_specification=ParameterSpecification(parameter_name=k, nested_element_key=k),
            category=NodeInputCategory(),
            _batch_oriented=True,
            _dimensionality=1,
            _points_to_input=True,
            _selector=k,
            _data_lineage=[0]
        )
        for k in keys
    }
    compound = CompoundStepInputDefinition(params, _list=False)
    step_node = StepNode(
        name="step_large_compound",
        input_data={"compound": compound},
        execution_branches_impacting_inputs=["branch1"],
        step_execution_dimensionality=1,
        step_manifest=DummyManifest,
        auto_batch_casting_lineage_supports={}
    )
    runtime_parameters = {k: list(range(N)) for k in keys}
    masks = {"branch1": set((i,) for i in range(N))}
    branching_manager = BranchingManager(masks)
    dynamic_batches_manager = DynamicBatchesManager({(0,): [(i,) for i in range(N)]})
    execution_cache = ExecutionCache({}, {})

    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 317μs -> 222μs (42.7% faster)
    expected = BatchModeSIMDStepInput(
        indices=[(i,) for i in range(N)],
        parameters={
            "compound": {
                k: Batch(list(range(N)), [(i,) for i in range(N)]) for k in keys
            }
        }
    )

def test_large_mask_removal():
    """
    Test a large batch input with a mask that removes half the elements.
    """
    N = 1000
    param = DynamicStepInputDefinition(
        parameter_specification=ParameterSpecification(parameter_name="input1"),
        category=NodeInputCategory(),
        _batch_oriented=True,
        _dimensionality=1,
        _points_to_input=True,
        _selector="input1",
        _data_lineage=[0]
    )
    step_node = StepNode(
        name="step_large_mask",
        input_data={"input1": param},
        execution_branches_impacting_inputs=["branch1"],
        step_execution_dimensionality=1,
        step_manifest=DummyManifest,
        auto_batch_casting_lineage_supports={}
    )
    runtime_parameters = {"input1": list(range(N))}
    mask_indices = set((i,) for i in range(0, N, 2))  # keep even indices
    masks = {"branch1": mask_indices}
    branching_manager = BranchingManager(masks)
    dynamic_batches_manager = DynamicBatchesManager({(0,): [(i,) for i in range(N)]})
    execution_cache = ExecutionCache({}, {})

    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 324μs -> 326μs (0.473% slower)
    kept_indices = [(i,) for i in range(0, N, 2)]
    kept_values = [i for i in range(0, N, 2)]
    expected_batch = Batch(
        [None if i % 2 == 1 else i for i in range(N)],
        [(i,) for i in range(N)]
    ).remove_by_indices(set((i,) for i in range(1, N, 2)))
    expected = BatchModeSIMDStepInput(
        indices=kept_indices,
        parameters={"input1": expected_batch}
    )
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from typing import Any, Dict, List, Optional, Set, Tuple, Union

# imports
import pytest
from inference.core.workflows.execution_engine.v1.executor.execution_data_manager.step_input_assembler import \
    construct_simd_step_input


class ExecutionEngineRuntimeError(Exception):
    def __init__(self, public_message, context=None):
        super().__init__(public_message)
        self.public_message = public_message
        self.context = context

# Simulate Batch and BatchModeSIMDStepInput
class Batch:
    def __init__(self, data, indices):
        self.data = data
        self.indices = indices

    def iter_with_indices(self):
        return zip(self.indices, self.data)

    def remove_by_indices(self, indices_to_remove):
        # Remove elements by index from data/indices
        new_data = []
        new_indices = []
        for d, idx in zip(self.data, self.indices):
            if idx not in indices_to_remove:
                new_data.append(d)
                new_indices.append(idx)
        return Batch(new_data, new_indices)

    def __eq__(self, other):
        if not isinstance(other, Batch):
            return False
        return self.data == other.data and self.indices == other.indices

    def __repr__(self):
        return f"Batch(data={self.data}, indices={self.indices})"

# ParameterSpecification and NodeInputCategory stubs
class ParameterSpecification:
    def __init__(self, parameter_name, nested_element_key=None):
        self.parameter_name = parameter_name
        self.nested_element_key = nested_element_key

class NodeInputCategory:
    pass

# StepInputDefinition, CompoundStepInputDefinition, DynamicStepInputDefinition, StaticStepInputDefinition stubs
class StepInputDefinition:
    def __init__(self, parameter_specification, category, is_batch=False, points_to_input=True, points_to_step_output=False, value=None, selector=None, dimensionality=1, data_lineage=None):
        self.parameter_specification = parameter_specification
        self.category = category
        self._is_batch = is_batch
        self._points_to_input = points_to_input
        self._points_to_step_output = points_to_step_output
        self.value = value
        self.selector = selector or parameter_specification.parameter_name
        self._dimensionality = dimensionality
        self.data_lineage = data_lineage or tuple([0]*dimensionality)

    def is_compound_input(self):
        return False

    def get_dimensionality(self):
        return self._dimensionality

    def is_batch_oriented(self):
        return self._is_batch

    def points_to_input(self):
        return self._points_to_input

    def points_to_step_output(self):
        return self._points_to_step_output

class CompoundStepInputDefinition:
    def __init__(self, definitions, is_list=False):
        self._definitions = definitions
        self._is_list = is_list

    def is_compound_input(self):
        return True

    def iterate_through_definitions(self):
        return self._definitions

    def represents_list_of_inputs(self):
        return self._is_list

    def get_dimensionality(self):
        # Max dimensionality of nested definitions
        if not self._definitions:
            return 0
        return max([d.get_dimensionality() for d in self._definitions])

# StepNode stub
class StepNode:
    def __init__(self, name, input_data, step_manifest, step_execution_dimensionality=1, execution_branches_impacting_inputs=None, auto_batch_casting_lineage_supports=None):
        self.name = name
        self.input_data = input_data
        self.step_manifest = step_manifest
        self.step_execution_dimensionality = step_execution_dimensionality
        self.execution_branches_impacting_inputs = execution_branches_impacting_inputs or []
        self.auto_batch_casting_lineage_supports = auto_batch_casting_lineage_supports or {}

# WorkflowBlockManifest stub
class WorkflowBlockManifest:
    def __init__(self, accepts_batch=True, accepts_empty=False):
        self._accepts_batch = accepts_batch
        self._accepts_empty = accepts_empty

    def accepts_batch_input(self):
        return self._accepts_batch

    def accepts_empty_values(self):
        return self._accepts_empty

# ExecutionCache stub
class ExecutionCache:
    def __init__(self, batch_outputs=None, non_batch_outputs=None):
        self.batch_outputs = batch_outputs or {}
        self.non_batch_outputs = non_batch_outputs or {}

    def get_batch_output(self, selector, batch_elements_indices, mask=None):
        # Return list of outputs for the given indices, optionally mask out
        outputs = self.batch_outputs.get(selector, [])
        if mask is not None:
            return [out if idx in mask else None for out, idx in zip(outputs, batch_elements_indices)]
        return [outputs[i] for i in range(len(batch_elements_indices))]

    def get_non_batch_output(self, selector):
        return self.non_batch_outputs.get(selector, None)

# DynamicBatchesManager stub
class DynamicBatchesManager:
    def __init__(self, lineage_to_indices):
        self.lineage_to_indices = lineage_to_indices

    def get_indices_for_data_lineage(self, lineage):
        return self.lineage_to_indices.get(tuple(lineage), [])

# BranchingManager stub (already defined in prompt, but simplified here)
class BranchingManager:
    def __init__(self, masks):
        self._masks = masks
        self._batch_compatibility = {k: not isinstance(v, bool) for k, v in masks.items()}

    def get_mask(self, execution_branch):
        return self._masks[execution_branch]

    def is_execution_branch_batch_oriented(self, execution_branch):
        return self._batch_compatibility[execution_branch]

    def is_execution_branch_registered(self, execution_branch):
        return execution_branch in self._masks

# --- The function under test (copied from prompt) ---
# ... [construct_simd_step_input and all helpers as in the prompt] ...
# For brevity, we assume the code above is already included and available here.

# --- TEST SUITE ---

# 1. BASIC TEST CASES

def test_single_batch_input_simple():
    # Test: Single batch input, all indices present, no masks
    input_param = StepInputDefinition(ParameterSpecification("input1"), NodeInputCategory(), is_batch=True, points_to_input=True, dimensionality=1)
    step_manifest = WorkflowBlockManifest(accepts_batch=True)
    step_node = StepNode(
        name="test_step",
        input_data={"input1": input_param},
        step_manifest=step_manifest,
        step_execution_dimensionality=1,
        execution_branches_impacting_inputs=[],
    )
    runtime_parameters = {"input1": ["a", "b", "c"]}
    execution_cache = ExecutionCache()
    dynamic_batches_manager = DynamicBatchesManager({(0,): [0, 1, 2]})
    branching_manager = BranchingManager({})
    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 20.1μs -> 16.6μs (20.6% faster)

def test_multiple_batch_inputs_same_indices():
    # Test: Two batch inputs, same indices, no masks
    input1 = StepInputDefinition(ParameterSpecification("input1"), NodeInputCategory(), is_batch=True, points_to_input=True, dimensionality=1)
    input2 = StepInputDefinition(ParameterSpecification("input2"), NodeInputCategory(), is_batch=True, points_to_input=True, dimensionality=1)
    step_manifest = WorkflowBlockManifest(accepts_batch=True)
    step_node = StepNode(
        name="test_step",
        input_data={"input1": input1, "input2": input2},
        step_manifest=step_manifest,
        step_execution_dimensionality=1,
        execution_branches_impacting_inputs=[],
    )
    runtime_parameters = {"input1": [1, 2], "input2": [3, 4]}
    execution_cache = ExecutionCache()
    dynamic_batches_manager = DynamicBatchesManager({(0,): [0, 1]})
    branching_manager = BranchingManager({})
    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 24.4μs -> 19.4μs (25.6% faster)

def test_batch_and_scalar_input():
    # Test: One batch input, one scalar input
    input1 = StepInputDefinition(ParameterSpecification("input1"), NodeInputCategory(), is_batch=True, points_to_input=True, dimensionality=1)
    input2 = StepInputDefinition(ParameterSpecification("input2"), NodeInputCategory(), is_batch=False, points_to_input=True, dimensionality=0)
    step_manifest = WorkflowBlockManifest(accepts_batch=True)
    step_node = StepNode(
        name="test_step",
        input_data={"input1": input1, "input2": input2},
        step_manifest=step_manifest,
        step_execution_dimensionality=1,
        execution_branches_impacting_inputs=[],
    )
    runtime_parameters = {"input1": [10, 20], "input2": 99}
    execution_cache = ExecutionCache()
    dynamic_batches_manager = DynamicBatchesManager({(0,): [0, 1]})
    branching_manager = BranchingManager({})
    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 20.9μs -> 16.8μs (24.6% faster)

def test_compound_input_dict():
    # Test: Compound input as dict of batch inputs
    input1 = StepInputDefinition(ParameterSpecification("x"), NodeInputCategory(), is_batch=True, points_to_input=True, dimensionality=1)
    input2 = StepInputDefinition(ParameterSpecification("y"), NodeInputCategory(), is_batch=True, points_to_input=True, dimensionality=1)
    compound = CompoundStepInputDefinition([input1, input2], is_list=False)
    step_manifest = WorkflowBlockManifest(accepts_batch=True)
    step_node = StepNode(
        name="test_step",
        input_data={"point": compound},
        step_manifest=step_manifest,
        step_execution_dimensionality=1,
        execution_branches_impacting_inputs=[],
    )
    runtime_parameters = {"x": [1, 2], "y": [3, 4]}
    execution_cache = ExecutionCache()
    dynamic_batches_manager = DynamicBatchesManager({(0,): [0, 1]})
    branching_manager = BranchingManager({})
    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 25.5μs -> 21.7μs (17.9% faster)

def test_compound_input_list():
    # Test: Compound input as list of batch inputs
    input1 = StepInputDefinition(ParameterSpecification("a"), NodeInputCategory(), is_batch=True, points_to_input=True, dimensionality=1)
    input2 = StepInputDefinition(ParameterSpecification("b"), NodeInputCategory(), is_batch=True, points_to_input=True, dimensionality=1)
    compound = CompoundStepInputDefinition([input1, input2], is_list=True)
    step_manifest = WorkflowBlockManifest(accepts_batch=True)
    step_node = StepNode(
        name="test_step",
        input_data={"listparam": compound},
        step_manifest=step_manifest,
        step_execution_dimensionality=1,
        execution_branches_impacting_inputs=[],
    )
    runtime_parameters = {"a": [7, 8], "b": [9, 10]}
    execution_cache = ExecutionCache()
    dynamic_batches_manager = DynamicBatchesManager({(0,): [0, 1]})
    branching_manager = BranchingManager({})
    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 26.4μs -> 21.5μs (23.1% faster)

# 2. EDGE TEST CASES

def test_empty_batch_input():
    # Test: Empty batch input list
    input1 = StepInputDefinition(ParameterSpecification("input1"), NodeInputCategory(), is_batch=True, points_to_input=True, dimensionality=1)
    step_manifest = WorkflowBlockManifest(accepts_batch=True)
    step_node = StepNode(
        name="test_step",
        input_data={"input1": input1},
        step_manifest=step_manifest,
        step_execution_dimensionality=1,
        execution_branches_impacting_inputs=[],
    )
    runtime_parameters = {"input1": []}
    execution_cache = ExecutionCache()
    dynamic_batches_manager = DynamicBatchesManager({(0,): []})
    branching_manager = BranchingManager({})
    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 15.4μs -> 13.0μs (18.9% faster)




def test_accepts_empty_values_returns_empty():
    # Test: Step manifest accepts empty values, and batch input is empty
    input1 = StepInputDefinition(ParameterSpecification("input1"), NodeInputCategory(), is_batch=True, points_to_input=True, dimensionality=1)
    step_manifest = WorkflowBlockManifest(accepts_batch=True, accepts_empty=True)
    step_node = StepNode(
        name="test_step",
        input_data={"input1": input1},
        step_manifest=step_manifest,
        step_execution_dimensionality=1,
        execution_branches_impacting_inputs=[],
    )
    runtime_parameters = {"input1": []}
    execution_cache = ExecutionCache()
    dynamic_batches_manager = DynamicBatchesManager({(0,): []})
    branching_manager = BranchingManager({})
    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 13.0μs -> 13.5μs (3.72% slower)

# 3. LARGE SCALE TEST CASES

def test_large_batch_input():
    # Test: Large batch input, check performance and correctness
    N = 1000
    input1 = StepInputDefinition(ParameterSpecification("input1"), NodeInputCategory(), is_batch=True, points_to_input=True, dimensionality=1)
    step_manifest = WorkflowBlockManifest(accepts_batch=True)
    step_node = StepNode(
        name="test_step",
        input_data={"input1": input1},
        step_manifest=step_manifest,
        step_execution_dimensionality=1,
        execution_branches_impacting_inputs=[],
    )
    runtime_parameters = {"input1": list(range(N))}
    execution_cache = ExecutionCache()
    dynamic_batches_manager = DynamicBatchesManager({(0,): list(range(N))})
    branching_manager = BranchingManager({})
    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 171μs -> 83.2μs (107% faster)

def test_large_compound_input():
    # Test: Large compound input (dict of batch inputs)
    N = 500
    input_defs = []
    runtime_parameters = {}
    for i in range(5):
        pname = f"p{i}"
        input_defs.append(StepInputDefinition(ParameterSpecification(pname), NodeInputCategory(), is_batch=True, points_to_input=True, dimensionality=1))
        runtime_parameters[pname] = list(range(N))
    compound = CompoundStepInputDefinition(input_defs, is_list=False)
    step_manifest = WorkflowBlockManifest(accepts_batch=True)
    step_node = StepNode(
        name="test_step",
        input_data={"comp": compound},
        step_manifest=step_manifest,
        step_execution_dimensionality=1,
        execution_branches_impacting_inputs=[],
    )
    execution_cache = ExecutionCache()
    dynamic_batches_manager = DynamicBatchesManager({(0,): list(range(N))})
    branching_manager = BranchingManager({})
    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 160μs -> 111μs (44.9% faster)
    for i in range(5):
        pname = f"p{i}"


def test_large_compound_list_input():
    # Test: Large compound input as list of batch inputs
    N = 200
    input_defs = []
    runtime_parameters = {}
    for i in range(10):
        pname = f"arr{i}"
        input_defs.append(StepInputDefinition(ParameterSpecification(pname), NodeInputCategory(), is_batch=True, points_to_input=True, dimensionality=1))
        runtime_parameters[pname] = list(range(N))
    compound = CompoundStepInputDefinition(input_defs, is_list=True)
    step_manifest = WorkflowBlockManifest(accepts_batch=True)
    step_node = StepNode(
        name="test_step",
        input_data={"arrs": compound},
        step_manifest=step_manifest,
        step_execution_dimensionality=1,
        execution_branches_impacting_inputs=[],
    )
    execution_cache = ExecutionCache()
    dynamic_batches_manager = DynamicBatchesManager({(0,): list(range(N))})
    branching_manager = BranchingManager({})
    codeflash_output = construct_simd_step_input(
        step_node, runtime_parameters, execution_cache, dynamic_batches_manager, branching_manager
    ); result = codeflash_output # 359μs -> 206μs (73.9% faster)
    for i in range(10):
        pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-pr1504-2025-08-25T10.24.18 and push.

Codeflash

…`feature/try-to-beat-the-limitation-of-ee-in-terms-of-singular-elements-pushed-into-batch-inputs`)

The optimized code achieves a **36% speedup** through a single but impactful conditional check optimization in the `prepare_parameters` function.

**Key Optimization:**
The main performance improvement comes from adding an `if empty_indices:` check before executing expensive list comprehension and data removal operations:

```python
# Original: Always executes these expensive operations
indices = [e for e in indices if e not in empty_indices]
result = remove_indices(value=result, indices=empty_indices)

# Optimized: Only executes when empty_indices is non-empty
if empty_indices:
    indices = [e for e in indices if e not in empty_indices]
    result = remove_indices(value=result, indices=empty_indices)
```

**Why this optimization works:**
- In many test cases, `empty_indices` is an empty set, making the filtering operations unnecessary
- The list comprehension `[e for e in indices if e not in empty_indices]` has O(n*m) complexity where n=len(indices) and m=len(empty_indices)
- `remove_indices()` recursively processes nested data structures, which is expensive even for empty removal sets
- By avoiding these operations when `empty_indices` is empty, we eliminate significant computational overhead

**Performance impact by test case type:**
- **Large batch inputs** see the biggest gains (43-107% faster) because they avoid expensive O(n) operations on large datasets when no filtering is needed
- **Basic test cases** show consistent 15-25% improvements from avoiding unnecessary operations
- **Edge cases with actual empty elements** may see minimal or slightly negative impact (0.5% slower) due to the additional conditional check, but this is negligible compared to the gains in common cases

This optimization is particularly effective because most workflow executions don't have empty batch elements that need filtering, making the conditional check a highly beneficial guard against unnecessary work.
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Aug 25, 2025
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Aug 25, 2025
@PawelPeczek-Roboflow PawelPeczek-Roboflow merged commit 37c120b into feature/try-to-beat-the-limitation-of-ee-in-terms-of-singular-elements-pushed-into-batch-inputs Aug 25, 2025
2 checks passed
@PawelPeczek-Roboflow PawelPeczek-Roboflow deleted the codeflash/optimize-pr1504-2025-08-25T10.24.18 branch August 25, 2025 14:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants