11import gc
22import json
33import os
4- import tempfile
54from typing import Callable
65
76import pytest
@@ -341,16 +340,15 @@ def test_components_auto_cpu_offload_inference_consistent(self):
341340
342341 assert torch .abs (image_slices [0 ] - image_slices [1 ]).max () < 1e-3
343342
344- def test_save_from_pretrained (self ):
343+ def test_save_from_pretrained (self , tmp_path ):
345344 pipes = []
346345 base_pipe = self .get_pipeline ().to (torch_device )
347346 pipes .append (base_pipe )
348347
349- with tempfile .TemporaryDirectory () as tmpdirname :
350- base_pipe .save_pretrained (tmpdirname )
351- pipe = ModularPipeline .from_pretrained (tmpdirname ).to (torch_device )
352- pipe .load_components (torch_dtype = torch .float32 )
353- pipe .to (torch_device )
348+ base_pipe .save_pretrained (str (tmp_path ))
349+ pipe = ModularPipeline .from_pretrained (tmp_path ).to (torch_device )
350+ pipe .load_components (torch_dtype = torch .float32 )
351+ pipe .to (torch_device )
354352
355353 pipes .append (pipe )
356354
@@ -362,32 +360,31 @@ def test_save_from_pretrained(self):
362360
363361 assert torch .abs (image_slices [0 ] - image_slices [1 ]).max () < 1e-3
364362
365- def test_modular_index_consistency (self ):
363+ def test_modular_index_consistency (self , tmp_path ):
366364 pipe = self .get_pipeline ()
367365 components_spec = pipe ._component_specs
368366 components = sorted (components_spec .keys ())
369367
370- with tempfile .TemporaryDirectory () as tmpdir :
371- pipe .save_pretrained (tmpdir )
372- index_file = os .path .join (tmpdir , "modular_model_index.json" )
373- assert os .path .exists (index_file )
368+ pipe .save_pretrained (str (tmp_path ))
369+ index_file = tmp_path / "modular_model_index.json"
370+ assert index_file .exists ()
374371
375- with open (index_file ) as f :
376- index_contents = json .load (f )
372+ with open (index_file ) as f :
373+ index_contents = json .load (f )
377374
378- compulsory_keys = {"_blocks_class_name" , "_class_name" , "_diffusers_version" }
379- for k in compulsory_keys :
380- assert k in index_contents
375+ compulsory_keys = {"_blocks_class_name" , "_class_name" , "_diffusers_version" }
376+ for k in compulsory_keys :
377+ assert k in index_contents
381378
382- to_check_attrs = {"pretrained_model_name_or_path" , "revision" , "subfolder" }
383- for component in components :
384- spec = components_spec [component ]
385- for attr in to_check_attrs :
386- if getattr (spec , "pretrained_model_name_or_path" , None ) is not None :
387- for attr in to_check_attrs :
388- assert component in index_contents , f"{ component } should be present in index but isn't."
389- attr_value_from_index = index_contents [component ][2 ][attr ]
390- assert getattr (spec , attr ) == attr_value_from_index
379+ to_check_attrs = {"pretrained_model_name_or_path" , "revision" , "subfolder" }
380+ for component in components :
381+ spec = components_spec [component ]
382+ for attr in to_check_attrs :
383+ if getattr (spec , "pretrained_model_name_or_path" , None ) is not None :
384+ for attr in to_check_attrs :
385+ assert component in index_contents , f"{ component } should be present in index but isn't."
386+ attr_value_from_index = index_contents [component ][2 ][attr ]
387+ assert getattr (spec , attr ) == attr_value_from_index
391388
392389 def test_workflow_map (self ):
393390 blocks = self .pipeline_blocks_class ()
@@ -483,7 +480,7 @@ class DummyBlockTwo:
483480
484481 def test_sequential_block_requirements_save_load (self , tmp_path ):
485482 pipe = self .get_dummy_block_pipe ()
486- pipe .save_pretrained (tmp_path )
483+ pipe .save_pretrained (str ( tmp_path ) )
487484
488485 config_path = tmp_path / "modular_config.json"
489486
@@ -508,7 +505,7 @@ def test_sequential_block_requirements_warnings(self, tmp_path):
508505 logger .setLevel (30 )
509506
510507 with CaptureLogger (logger ) as cap_logger :
511- pipe .save_pretrained (tmp_path )
508+ pipe .save_pretrained (str ( tmp_path ) )
512509
513510 template = "{req} was specified in the requirements but wasn't found in the current environment"
514511 msg_xyz = template .format (req = "xyz" )
@@ -518,7 +515,7 @@ def test_sequential_block_requirements_warnings(self, tmp_path):
518515
519516 def test_conditional_block_requirements_save_load (self , tmp_path ):
520517 pipe = self .get_dummy_conditional_block_pipe ()
521- pipe .save_pretrained (tmp_path )
518+ pipe .save_pretrained (str ( tmp_path ) )
522519
523520 config_path = tmp_path / "modular_config.json"
524521 with open (config_path , "r" ) as f :
@@ -535,7 +532,7 @@ def test_conditional_block_requirements_save_load(self, tmp_path):
535532
536533 def test_loop_block_requirements_save_load (self , tmp_path ):
537534 pipe = self .get_dummy_loop_block_pipe ()
538- pipe .save_pretrained (tmp_path )
535+ pipe .save_pretrained (str ( tmp_path ) )
539536
540537 config_path = tmp_path / "modular_config.json"
541538 with open (config_path , "r" ) as f :
0 commit comments