Skip to content

Commit 1b8fc6c

Browse files
authored
[modular] change the template modular pipeline card (#13072)
* start better template for modular pipeline card. * simplify structure. * refine. * style. * up * add tests
1 parent 6d4fc6b commit 1b8fc6c

File tree

4 files changed

+467
-6
lines changed

4 files changed

+467
-6
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@
3434
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
3535
from .components_manager import ComponentsManager
3636
from .modular_pipeline_utils import (
37+
MODULAR_MODEL_CARD_TEMPLATE,
3738
ComponentSpec,
3839
ConfigSpec,
3940
InputParam,
4041
InsertableDict,
4142
OutputParam,
4243
format_components,
4344
format_configs,
45+
generate_modular_model_card_content,
4446
make_doc_string,
4547
)
4648

@@ -1753,9 +1755,19 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
17531755
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
17541756
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
17551757

1758+
# Generate modular pipeline card content
1759+
card_content = generate_modular_model_card_content(self.blocks)
1760+
17561761
# Create a new empty model card and eventually tag it
1757-
model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
1758-
model_card = populate_model_card(model_card)
1762+
model_card = load_or_create_model_card(
1763+
repo_id,
1764+
token=token,
1765+
is_pipeline=True,
1766+
model_description=MODULAR_MODEL_CARD_TEMPLATE.format(**card_content),
1767+
is_modular=True,
1768+
)
1769+
model_card = populate_model_card(model_card, tags=card_content["tags"])
1770+
17591771
model_card.save(os.path.join(save_directory, "README.md"))
17601772

17611773
# YiYi TODO: maybe order the json file to make it more readable: configs first, then components

src/diffusers/modular_pipelines/modular_pipeline_utils.py

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,30 @@
3131

3232
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3333

34+
# Template for the modular pipeline model card description.
#
# Filled with `str.format`; the placeholders are `model_description`, `blocks_description`,
# `trigger_inputs_section`, `components_description`, `configs_section`, `inputs_description` and
# `outputs_description`. Every block-level placeholder sits on its own line so that the substituted Markdown
# (lists, bold labels) is never merged into the preceding heading, and so that an empty optional section does
# not leave trailing whitespace behind.
MODULAR_MODEL_CARD_TEMPLATE = """{model_description}

## Example Usage

[TODO]

## Pipeline Architecture

This modular pipeline is composed of the following blocks:

{blocks_description}
{trigger_inputs_section}

## Model Components

{components_description}{configs_section}

## Input/Output Specification

### Inputs

{inputs_description}

### Outputs

{outputs_description}
"""
57+
3458

3559
class InsertableDict(OrderedDict):
3660
def insert(self, key, value, index):
@@ -916,3 +940,178 @@ def make_doc_string(
916940
output += format_output_params(outputs, indent_level=2)
917941

918942
return output
943+
944+
945+
def _first_line(text) -> str:
    """Return the first line of `text`, or an empty string when `text` is empty or `None`."""
    return text.split("\n")[0] if text else ""


def _format_type_hint(type_hint) -> str:
    """Render a type hint for the model card: its `__name__` if present, else its `str()` form, else `Any`."""
    if hasattr(type_hint, "__name__"):
        return type_hint.__name__
    if type_hint is not None:
        return str(type_hint).replace("typing.", "")
    return "Any"


def generate_modular_model_card_content(blocks) -> Dict[str, Any]:
    """
    Generate model card content for a modular pipeline.

    The returned sections are plain Markdown strings meant to be substituted into a model card template, plus a list
    of tags for the card metadata.

    Args:
        blocks: The pipeline's blocks object. `inputs` and `outputs` are read directly; `description`, `sub_blocks`,
            `expected_components`, `expected_configs`, `trigger_inputs` and `model_name` are optional and read
            defensively.

    Returns:
        Dict[str, Any]: A dictionary containing formatted content sections:
            - pipeline_name: Class name of `blocks` with "Blocks" replaced by " Pipeline"
            - model_description: Overall description with pipeline type and block count
            - blocks_description: Numbered list of blocks with one level of nested sub-blocks
            - components_description: Enumerated list of expected components
            - configs_section: Configuration parameters section (empty string when there are none)
            - inputs_description: Input parameters, split into required and optional
            - outputs_description: Output parameters
            - trigger_inputs_section: Conditional execution information (empty string when there is none)
            - tags: List of tags for the model card
    """
    blocks_class_name = blocks.__class__.__name__
    pipeline_name = blocks_class_name.replace("Blocks", " Pipeline")
    # fall back to the generic text when the attribute is missing, empty or None
    description = getattr(blocks, "description", None) or "A modular diffusion pipeline."
    sub_blocks = getattr(blocks, "sub_blocks", None) or {}

    # generate blocks architecture description
    blocks_desc_parts = []
    for position, (name, block) in enumerate(sub_blocks.items(), start=1):
        blocks_desc_parts.append(f"{position}. **{name}** (`{block.__class__.__name__}`)")
        block_desc = _first_line(getattr(block, "description", ""))
        if block_desc:
            # three spaces align the bullet with the content of the "N. " item so Markdown nests it
            blocks_desc_parts.append(f"   - {block_desc}")

        # add sub-blocks if any (one nesting level only)
        for sub_name, sub_block in (getattr(block, "sub_blocks", None) or {}).items():
            blocks_desc_parts.append(f"   - *{sub_name}*: `{sub_block.__class__.__name__}`")
            sub_desc = _first_line(getattr(sub_block, "description", ""))
            if sub_desc:
                blocks_desc_parts.append(f"     - {sub_desc}")

    blocks_description = "\n".join(blocks_desc_parts) if blocks_desc_parts else "No blocks defined."

    components = getattr(blocks, "expected_components", [])
    if components:
        components_str = format_components(components, indent_level=0, add_empty_lines=False)
        # remove the "Components:" header since template has its own
        components_description = components_str.replace("Components:\n", "").strip()
        if components_description:
            # convert to enumerated list
            lines = [line.strip() for line in components_description.split("\n") if line.strip()]
            components_description = "\n".join(f"{i}. {line}" for i, line in enumerate(lines, start=1))
        else:
            components_description = "No specific components required."
    else:
        components_description = "No specific components required. Components can be loaded dynamically."

    configs = getattr(blocks, "expected_configs", [])
    configs_section = ""
    if configs:
        configs_str = format_configs(configs, indent_level=0, add_empty_lines=False)
        configs_description = configs_str.replace("Configs:\n", "").strip()
        if configs_description:
            configs_section = f"\n\n## Configuration Parameters\n\n{configs_description}"

    inputs = blocks.inputs
    outputs = blocks.outputs

    # format inputs as markdown list
    inputs_parts = []
    required_inputs = [inp for inp in inputs if inp.required]
    optional_inputs = [inp for inp in inputs if not inp.required]

    if required_inputs:
        inputs_parts.append("**Required:**\n")
        for inp in required_inputs:
            desc = inp.description or "No description provided"
            inputs_parts.append(f"- `{inp.name}` (`{_format_type_hint(inp.type_hint)}`): {desc}")

    if optional_inputs:
        if required_inputs:
            # blank separator between the two groups
            inputs_parts.append("")
        inputs_parts.append("**Optional:**\n")
        for inp in optional_inputs:
            desc = inp.description or "No description provided"
            default_str = f", default: `{inp.default}`" if inp.default is not None else ""
            inputs_parts.append(f"- `{inp.name}` (`{_format_type_hint(inp.type_hint)}`){default_str}: {desc}")

    inputs_description = "\n".join(inputs_parts) if inputs_parts else "No specific inputs defined."

    # format outputs as markdown list
    outputs_parts = [
        f"- `{out.name}` (`{_format_type_hint(out.type_hint)}`): {out.description or 'No description provided'}"
        for out in outputs
    ]
    outputs_description = "\n".join(outputs_parts) if outputs_parts else "Standard pipeline outputs."

    triggers = getattr(blocks, "trigger_inputs", None)

    trigger_inputs_section = ""
    if triggers:
        trigger_inputs_list = sorted([t for t in triggers if t is not None])
        if trigger_inputs_list:
            trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
            trigger_inputs_section = (
                "\n### Conditional Execution\n\n"
                "This pipeline contains blocks that are selected at runtime based on inputs:\n"
                f"- **Trigger Inputs**: {trigger_inputs_str}\n"
            )

    # generate tags based on pipeline characteristics
    tags = ["modular-diffusers", "diffusers"]

    model_name = getattr(blocks, "model_name", None)
    if model_name:
        tags.append(model_name)

    if triggers:
        if any(t in triggers for t in ["mask", "mask_image"]):
            tags.append("inpainting")
        if any(t in triggers for t in ["image", "image_latents"]):
            tags.append("image-to-image")
        if any(t in triggers for t in ["control_image", "controlnet_cond"]):
            tags.append("controlnet")
        if not any(t in triggers for t in ["image", "mask", "image_latents", "mask_image"]):
            tags.append("text-to-image")
    else:
        tags.append("text-to-image")

    # use the guarded mapping: `blocks.sub_blocks` may be missing or None
    block_count = len(sub_blocks)
    model_description = (
        "This is a modular diffusion pipeline built with 🧨 Diffusers' modular pipeline framework.\n"
        "\n"
        f"**Pipeline Type**: {blocks_class_name}\n"
        "\n"
        f"**Description**: {description}\n"
        "\n"
        f"This pipeline uses a {block_count}-block architecture that can be customized and extended."
    )

    return {
        "pipeline_name": pipeline_name,
        "model_description": model_description,
        "blocks_description": blocks_description,
        "components_description": components_description,
        "configs_section": configs_section,
        "inputs_description": inputs_description,
        "outputs_description": outputs_description,
        "trigger_inputs_section": trigger_inputs_section,
        "tags": tags,
    }

src/diffusers/utils/hub_utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def load_or_create_model_card(
107107
license: Optional[str] = None,
108108
widget: Optional[List[dict]] = None,
109109
inference: Optional[bool] = None,
110+
is_modular: bool = False,
110111
) -> ModelCard:
111112
"""
112113
Loads or creates a model card.
@@ -131,6 +132,8 @@ def load_or_create_model_card(
131132
widget (`List[dict]`, *optional*): Widget to accompany a gallery template.
132133
inference: (`bool`, optional): Whether to turn on inference widget. Helpful when using
133134
`load_or_create_model_card` from a training script.
135+
is_modular: (`bool`, optional): Boolean flag to denote if the model card is for a modular pipeline.
136+
When True, uses model_description as-is without additional template formatting.
134137
"""
135138
if not is_jinja_available():
136139
raise ValueError(
@@ -159,10 +162,14 @@ def load_or_create_model_card(
159162
)
160163
else:
161164
card_data = ModelCardData()
162-
component = "pipeline" if is_pipeline else "model"
163-
if model_description is None:
164-
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
165-
model_card = ModelCard.from_template(card_data, model_description=model_description)
165+
if is_modular and model_description is not None:
166+
model_card = ModelCard(model_description)
167+
model_card.data = card_data
168+
else:
169+
component = "pipeline" if is_pipeline else "model"
170+
if model_description is None:
171+
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
172+
model_card = ModelCard.from_template(card_data, model_description=model_description)
166173

167174
return model_card
168175

0 commit comments

Comments
 (0)