Skip to content
Merged
181 changes: 65 additions & 116 deletions src/diffusers/modular_pipelines/modular_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
</Tip>
"""

config_name = "config.json"
config_name = "modular_config.json"
Copy link
Collaborator Author

@DN6 DN6 Jul 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency with modular_model_index.json. Also could be cases where a repo contains model weights/config file and a modular pipeline block to load the model. We can avoid conflicts with the configs this way.

model_name = None

@classmethod
Expand All @@ -342,6 +342,16 @@ def expected_components(self) -> List[ComponentSpec]:
def expected_configs(self) -> List[ConfigSpec]:
return []

@property
def intermediate_inputs(self) -> List[OutputParam]:
"""List of intermediate output parameters. Must be implemented by subclasses."""
return []

@property
def intermediate_outputs(self) -> List[OutputParam]:
"""List of intermediate output parameters. Must be implemented by subclasses."""
return []

@classmethod
def from_pretrained(
cls,
Expand Down Expand Up @@ -423,6 +433,60 @@ def init_pipeline(
)
return modular_pipeline

def get_block_state(self, state: PipelineState) -> dict:
"""Get all inputs and intermediates in one dictionary"""
data = {}
state_inputs = self.inputs + self.intermediate_inputs

# Check inputs
for input_param in state_inputs:
if input_param.name:
value = state.get_input(input_param.name) or state.get_intermediate(input_param.name)
if input_param.required and value is None:
raise ValueError(f"Required input '{input_param.name}' is missing")
elif value is not None or (value is None and input_param.name not in data):
data[input_param.name] = value

elif input_param.kwargs_type:
# if kwargs_type is provided, get all inputs with matching kwargs_type
if input_param.kwargs_type not in data:
data[input_param.kwargs_type] = {}
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) or state.get_intermediate_kwargs(
input_param.kwargs_type
)
if inputs_kwargs:
for k, v in inputs_kwargs.items():
if v is not None:
data[k] = v
data[input_param.kwargs_type][k] = v

return BlockState(**data)

def set_block_state(self, state: PipelineState, block_state: BlockState):
for output_param in self.intermediate_outputs:
if not hasattr(block_state, output_param.name):
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
param = getattr(block_state, output_param.name)
state.set_intermediate(output_param.name, param, output_param.kwargs_type)

for input_param in self.intermediate_inputs:
if input_param.name and hasattr(block_state, input_param.name):
param = getattr(block_state, input_param.name)
# Only add if the value is different from what's in the state
current_value = state.get_intermediate(input_param.name)
if current_value is not param: # Using identity comparison to check if object was modified
state.set_intermediate(input_param.name, param, input_param.kwargs_type)
elif input_param.kwargs_type:
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
# we need to first find out which inputs are and loop through them.
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
for param_name, current_value in intermediate_kwargs.items():
if not hasattr(block_state, param_name):
continue
param = getattr(block_state, param_name)
if current_value is not param: # Using identity comparison to check if object was modified
state.set_intermediate(param_name, param, input_param.kwargs_type)

@staticmethod
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
"""
Expand Down Expand Up @@ -652,51 +716,6 @@ def doc(self):
expected_configs=self.expected_configs,
)

# YiYi TODO: input and inteermediate inputs with same name? should warn?
def get_block_state(self, state: PipelineState) -> dict:
"""Get all inputs and intermediates in one dictionary"""
data = {}

# Check inputs
for input_param in self.inputs:
if input_param.name:
value = state.get_input(input_param.name)
if input_param.required and value is None:
raise ValueError(f"Required input '{input_param.name}' is missing")
elif value is not None or (value is None and input_param.name not in data):
data[input_param.name] = value
elif input_param.kwargs_type:
# if kwargs_type is provided, get all inputs with matching kwargs_type
if input_param.kwargs_type not in data:
data[input_param.kwargs_type] = {}
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
if inputs_kwargs:
for k, v in inputs_kwargs.items():
if v is not None:
data[k] = v
data[input_param.kwargs_type][k] = v

# Check intermediates
for input_param in self.intermediate_inputs:
if input_param.name:
value = state.get_intermediate(input_param.name)
if input_param.required and value is None:
raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
elif value is not None or (value is None and input_param.name not in data):
data[input_param.name] = value
elif input_param.kwargs_type:
# if kwargs_type is provided, get all intermediates with matching kwargs_type
if input_param.kwargs_type not in data:
data[input_param.kwargs_type] = {}
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
if intermediate_kwargs:
for k, v in intermediate_kwargs.items():
if v is not None:
if k not in data:
data[k] = v
data[input_param.kwargs_type][k] = v
return BlockState(**data)

def set_block_state(self, state: PipelineState, block_state: BlockState):
for output_param in self.intermediate_outputs:
if not hasattr(block_state, output_param.name):
Expand Down Expand Up @@ -1633,75 +1652,6 @@ def loop_step(self, components, state: PipelineState, **kwargs):
def __call__(self, components, state: PipelineState) -> PipelineState:
raise NotImplementedError("`__call__` method needs to be implemented by the subclass")

def get_block_state(self, state: PipelineState) -> dict:
"""Get all inputs and intermediates in one dictionary"""
data = {}

# Check inputs
for input_param in self.inputs:
if input_param.name:
value = state.get_input(input_param.name)
if input_param.required and value is None:
raise ValueError(f"Required input '{input_param.name}' is missing")
elif value is not None or (value is None and input_param.name not in data):
data[input_param.name] = value
elif input_param.kwargs_type:
# if kwargs_type is provided, get all inputs with matching kwargs_type
if input_param.kwargs_type not in data:
data[input_param.kwargs_type] = {}
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
if inputs_kwargs:
for k, v in inputs_kwargs.items():
if v is not None:
data[k] = v
data[input_param.kwargs_type][k] = v

# Check intermediates
for input_param in self.intermediate_inputs:
if input_param.name:
value = state.get_intermediate(input_param.name)
if input_param.required and value is None:
raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
elif value is not None or (value is None and input_param.name not in data):
data[input_param.name] = value
elif input_param.kwargs_type:
# if kwargs_type is provided, get all intermediates with matching kwargs_type
if input_param.kwargs_type not in data:
data[input_param.kwargs_type] = {}
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
if intermediate_kwargs:
for k, v in intermediate_kwargs.items():
if v is not None:
if k not in data:
data[k] = v
data[input_param.kwargs_type][k] = v
return BlockState(**data)

def set_block_state(self, state: PipelineState, block_state: BlockState):
for output_param in self.intermediate_outputs:
if not hasattr(block_state, output_param.name):
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
param = getattr(block_state, output_param.name)
state.set_intermediate(output_param.name, param, output_param.kwargs_type)

for input_param in self.intermediate_inputs:
if input_param.name and hasattr(block_state, input_param.name):
param = getattr(block_state, input_param.name)
# Only add if the value is different from what's in the state
current_value = state.get_intermediate(input_param.name)
if current_value is not param: # Using identity comparison to check if object was modified
state.set_intermediate(input_param.name, param, input_param.kwargs_type)
elif input_param.kwargs_type:
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
# we need to first find out which inputs are and loop through them.
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
for param_name, current_value in intermediate_kwargs.items():
if not hasattr(block_state, param_name):
continue
param = getattr(block_state, param_name)
if current_value is not param: # Using identity comparison to check if object was modified
state.set_intermediate(param_name, param, input_param.kwargs_type)

@property
def doc(self):
return make_doc_string(
Expand Down Expand Up @@ -1974,7 +1924,6 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =

# Add inputs to state, using defaults if not provided in the kwargs or the state
# if same input already in the state, will override it if provided in the kwargs

intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs]
for expected_input_param in self.blocks.inputs:
name = expected_input_param.name
Expand Down
Loading