Skip to content

Commit a08c274

Browse files
authored
[tests] Use tmp_path fixture in modular tests (#13194)
* add a test to check modular index consistency
* check for compulsory keys
* use fixture for tmp_path in modular tests
* remove unneeded test
* fix code quality
* up
* up
1 parent 7f92d81 commit a08c274

File tree

3 files changed

+52
-61
lines changed

3 files changed

+52
-61
lines changed

tests/modular_pipelines/flux/test_modular_pipeline_flux.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# limitations under the License.
1515

1616
import random
17-
import tempfile
1817

1918
import numpy as np
2019
import PIL
@@ -129,18 +128,16 @@ def get_dummy_inputs(self, seed=0):
129128

130129
return inputs
131130

132-
def test_save_from_pretrained(self):
131+
def test_save_from_pretrained(self, tmp_path):
133132
pipes = []
134133
base_pipe = self.get_pipeline().to(torch_device)
135134
pipes.append(base_pipe)
136135

137-
with tempfile.TemporaryDirectory() as tmpdirname:
138-
base_pipe.save_pretrained(tmpdirname)
139-
140-
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
141-
pipe.load_components(torch_dtype=torch.float32)
142-
pipe.to(torch_device)
143-
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
136+
base_pipe.save_pretrained(str(tmp_path))
137+
pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device)
138+
pipe.load_components(torch_dtype=torch.float32)
139+
pipe.to(torch_device)
140+
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
144141

145142
pipes.append(pipe)
146143

@@ -212,18 +209,16 @@ def get_dummy_inputs(self, seed=0):
212209

213210
return inputs
214211

215-
def test_save_from_pretrained(self):
212+
def test_save_from_pretrained(self, tmp_path):
216213
pipes = []
217214
base_pipe = self.get_pipeline().to(torch_device)
218215
pipes.append(base_pipe)
219216

220-
with tempfile.TemporaryDirectory() as tmpdirname:
221-
base_pipe.save_pretrained(tmpdirname)
222-
223-
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
224-
pipe.load_components(torch_dtype=torch.float32)
225-
pipe.to(torch_device)
226-
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
217+
base_pipe.save_pretrained(str(tmp_path))
218+
pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device)
219+
pipe.load_components(torch_dtype=torch.float32)
220+
pipe.to(torch_device)
221+
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
227222

228223
pipes.append(pipe)
229224

tests/modular_pipelines/test_modular_pipelines_common.py

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import gc
22
import json
33
import os
4-
import tempfile
54
from typing import Callable
65

76
import pytest
@@ -341,16 +340,15 @@ def test_components_auto_cpu_offload_inference_consistent(self):
341340

342341
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
343342

344-
def test_save_from_pretrained(self):
343+
def test_save_from_pretrained(self, tmp_path):
345344
pipes = []
346345
base_pipe = self.get_pipeline().to(torch_device)
347346
pipes.append(base_pipe)
348347

349-
with tempfile.TemporaryDirectory() as tmpdirname:
350-
base_pipe.save_pretrained(tmpdirname)
351-
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
352-
pipe.load_components(torch_dtype=torch.float32)
353-
pipe.to(torch_device)
348+
base_pipe.save_pretrained(str(tmp_path))
349+
pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device)
350+
pipe.load_components(torch_dtype=torch.float32)
351+
pipe.to(torch_device)
354352

355353
pipes.append(pipe)
356354

@@ -362,32 +360,31 @@ def test_save_from_pretrained(self):
362360

363361
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
364362

365-
def test_modular_index_consistency(self):
363+
def test_modular_index_consistency(self, tmp_path):
366364
pipe = self.get_pipeline()
367365
components_spec = pipe._component_specs
368366
components = sorted(components_spec.keys())
369367

370-
with tempfile.TemporaryDirectory() as tmpdir:
371-
pipe.save_pretrained(tmpdir)
372-
index_file = os.path.join(tmpdir, "modular_model_index.json")
373-
assert os.path.exists(index_file)
368+
pipe.save_pretrained(str(tmp_path))
369+
index_file = tmp_path / "modular_model_index.json"
370+
assert index_file.exists()
374371

375-
with open(index_file) as f:
376-
index_contents = json.load(f)
372+
with open(index_file) as f:
373+
index_contents = json.load(f)
377374

378-
compulsory_keys = {"_blocks_class_name", "_class_name", "_diffusers_version"}
379-
for k in compulsory_keys:
380-
assert k in index_contents
375+
compulsory_keys = {"_blocks_class_name", "_class_name", "_diffusers_version"}
376+
for k in compulsory_keys:
377+
assert k in index_contents
381378

382-
to_check_attrs = {"pretrained_model_name_or_path", "revision", "subfolder"}
383-
for component in components:
384-
spec = components_spec[component]
385-
for attr in to_check_attrs:
386-
if getattr(spec, "pretrained_model_name_or_path", None) is not None:
387-
for attr in to_check_attrs:
388-
assert component in index_contents, f"{component} should be present in index but isn't."
389-
attr_value_from_index = index_contents[component][2][attr]
390-
assert getattr(spec, attr) == attr_value_from_index
379+
to_check_attrs = {"pretrained_model_name_or_path", "revision", "subfolder"}
380+
for component in components:
381+
spec = components_spec[component]
382+
for attr in to_check_attrs:
383+
if getattr(spec, "pretrained_model_name_or_path", None) is not None:
384+
for attr in to_check_attrs:
385+
assert component in index_contents, f"{component} should be present in index but isn't."
386+
attr_value_from_index = index_contents[component][2][attr]
387+
assert getattr(spec, attr) == attr_value_from_index
391388

392389
def test_workflow_map(self):
393390
blocks = self.pipeline_blocks_class()
@@ -483,7 +480,7 @@ class DummyBlockTwo:
483480

484481
def test_sequential_block_requirements_save_load(self, tmp_path):
485482
pipe = self.get_dummy_block_pipe()
486-
pipe.save_pretrained(tmp_path)
483+
pipe.save_pretrained(str(tmp_path))
487484

488485
config_path = tmp_path / "modular_config.json"
489486

@@ -508,7 +505,7 @@ def test_sequential_block_requirements_warnings(self, tmp_path):
508505
logger.setLevel(30)
509506

510507
with CaptureLogger(logger) as cap_logger:
511-
pipe.save_pretrained(tmp_path)
508+
pipe.save_pretrained(str(tmp_path))
512509

513510
template = "{req} was specified in the requirements but wasn't found in the current environment"
514511
msg_xyz = template.format(req="xyz")
@@ -518,7 +515,7 @@ def test_sequential_block_requirements_warnings(self, tmp_path):
518515

519516
def test_conditional_block_requirements_save_load(self, tmp_path):
520517
pipe = self.get_dummy_conditional_block_pipe()
521-
pipe.save_pretrained(tmp_path)
518+
pipe.save_pretrained(str(tmp_path))
522519

523520
config_path = tmp_path / "modular_config.json"
524521
with open(config_path, "r") as f:
@@ -535,7 +532,7 @@ def test_conditional_block_requirements_save_load(self, tmp_path):
535532

536533
def test_loop_block_requirements_save_load(self, tmp_path):
537534
pipe = self.get_dummy_loop_block_pipe()
538-
pipe.save_pretrained(tmp_path)
535+
pipe.save_pretrained(str(tmp_path))
539536

540537
config_path = tmp_path / "modular_config.json"
541538
with open(config_path, "r") as f:

tests/modular_pipelines/test_modular_pipelines_custom_blocks.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -153,25 +153,24 @@ def test_custom_block_output(self):
153153
output_prompt = output.values["output_prompt"]
154154
assert output_prompt.startswith("Modular diffusers + ")
155155

156-
def test_custom_block_saving_loading(self):
156+
def test_custom_block_saving_loading(self, tmp_path):
157157
custom_block = DummyCustomBlockSimple()
158158

159-
with tempfile.TemporaryDirectory() as tmpdir:
160-
custom_block.save_pretrained(tmpdir)
161-
assert any("modular_config.json" in k for k in os.listdir(tmpdir))
159+
custom_block.save_pretrained(tmp_path)
160+
assert any("modular_config.json" in k for k in os.listdir(tmp_path))
162161

163-
with open(os.path.join(tmpdir, "modular_config.json"), "r") as f:
164-
config = json.load(f)
165-
auto_map = config["auto_map"]
166-
assert auto_map == {"ModularPipelineBlocks": "test_modular_pipelines_custom_blocks.DummyCustomBlockSimple"}
162+
with open(os.path.join(tmp_path, "modular_config.json"), "r") as f:
163+
config = json.load(f)
164+
auto_map = config["auto_map"]
165+
assert auto_map == {"ModularPipelineBlocks": "test_modular_pipelines_custom_blocks.DummyCustomBlockSimple"}
167166

168-
# For now, the Python script that implements the custom block has to be manually pushed to the Hub.
169-
# This is why, we have to separately save the Python script here.
170-
code_path = os.path.join(tmpdir, "test_modular_pipelines_custom_blocks.py")
171-
with open(code_path, "w") as f:
172-
f.write(CODE_STR)
167+
# For now, the Python script that implements the custom block has to be manually pushed to the Hub.
168+
# This is why, we have to separately save the Python script here.
169+
code_path = os.path.join(tmp_path, "test_modular_pipelines_custom_blocks.py")
170+
with open(code_path, "w") as f:
171+
f.write(CODE_STR)
173172

174-
loaded_custom_block = ModularPipelineBlocks.from_pretrained(tmpdir, trust_remote_code=True)
173+
loaded_custom_block = ModularPipelineBlocks.from_pretrained(tmp_path, trust_remote_code=True)
175174

176175
pipe = loaded_custom_block.init_pipeline()
177176
prompt = "Diffusers is nice"

0 commit comments

Comments (0)