Skip to content

Commit d190959

Browse files
Make sure DDPM and diffusers can be used without Transformers (#5668)
* fix: import bug * fix * fix * fix import utils for lcm * fix: pixart alpha init * Fix --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent d5ff8f8 commit d190959

File tree

3 files changed

+82
-11
lines changed

3 files changed

+82
-11
lines changed

src/diffusers/loaders.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2390,7 +2390,7 @@ def unfuse_text_encoder_lora(text_encoder):
23902390
def set_adapters_for_text_encoder(
23912391
self,
23922392
adapter_names: Union[List[str], str],
2393-
text_encoder: Optional[PreTrainedModel] = None,
2393+
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
23942394
text_encoder_weights: List[float] = None,
23952395
):
23962396
"""
@@ -2429,7 +2429,7 @@ def process_weights(adapter_names, weights):
24292429
)
24302430
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
24312431

2432-
def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
2432+
def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
24332433
"""
24342434
Disables the LoRA layers for the text encoder.
24352435
@@ -2446,7 +2446,7 @@ def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel]
24462446
raise ValueError("Text Encoder not found.")
24472447
set_adapter_layers(text_encoder, enabled=False)
24482448

2449-
def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
2449+
def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
24502450
"""
24512451
Enables the LoRA layers for the text encoder.
24522452
Lines changed: 31 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,40 @@
11
from typing import TYPE_CHECKING
22

33
from ...utils import (
4+
DIFFUSERS_SLOW_IMPORT,
5+
OptionalDependencyNotAvailable,
46
_LazyModule,
7+
get_objects_from_module,
8+
is_torch_available,
9+
is_transformers_available,
510
)
611

712

8-
_import_structure = {
9-
"pipeline_latent_consistency_img2img": ["LatentConsistencyModelImg2ImgPipeline"],
10-
"pipeline_latent_consistency_text2img": ["LatentConsistencyModelPipeline"],
11-
}
13+
_dummy_objects = {}
14+
_import_structure = {}
1215

1316

14-
if TYPE_CHECKING:
15-
from .pipeline_latent_consistency_img2img import LatentConsistencyModelImg2ImgPipeline
16-
from .pipeline_latent_consistency_text2img import LatentConsistencyModelPipeline
17+
try:
18+
if not (is_transformers_available() and is_torch_available()):
19+
raise OptionalDependencyNotAvailable()
20+
except OptionalDependencyNotAvailable:
21+
from ...utils import dummy_torch_and_transformers_objects # noqa F403
22+
23+
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
24+
else:
25+
_import_structure["pipeline_latent_consistency_img2img"] = ["LatentConsistencyModelImg2ImgPipeline"]
26+
_import_structure["pipeline_latent_consistency_text2img"] = ["LatentConsistencyModelPipeline"]
27+
28+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
29+
try:
30+
if not (is_transformers_available() and is_torch_available()):
31+
raise OptionalDependencyNotAvailable()
32+
33+
except OptionalDependencyNotAvailable:
34+
from ...utils.dummy_torch_and_transformers_objects import *
35+
else:
36+
from .pipeline_latent_consistency_img2img import LatentConsistencyModelImg2ImgPipeline
37+
from .pipeline_latent_consistency_text2img import LatentConsistencyModelPipeline
1738

1839
else:
1940
import sys
@@ -24,3 +45,6 @@
2445
_import_structure,
2546
module_spec=__spec__,
2647
)
48+
49+
for name, value in _dummy_objects.items():
50+
setattr(sys.modules[__name__], name, value)
Lines changed: 48 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1,48 @@
1-
from .pipeline_pixart_alpha import PixArtAlphaPipeline
1+
from typing import TYPE_CHECKING
2+
3+
from ...utils import (
4+
DIFFUSERS_SLOW_IMPORT,
5+
OptionalDependencyNotAvailable,
6+
_LazyModule,
7+
get_objects_from_module,
8+
is_torch_available,
9+
is_transformers_available,
10+
)
11+
12+
13+
_dummy_objects = {}
14+
_import_structure = {}
15+
16+
17+
try:
18+
if not (is_transformers_available() and is_torch_available()):
19+
raise OptionalDependencyNotAvailable()
20+
except OptionalDependencyNotAvailable:
21+
from ...utils import dummy_torch_and_transformers_objects # noqa F403
22+
23+
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
24+
else:
25+
_import_structure["pipeline_pixart_alpha"] = ["PixArtAlphaPipeline"]
26+
27+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
28+
try:
29+
if not (is_transformers_available() and is_torch_available()):
30+
raise OptionalDependencyNotAvailable()
31+
32+
except OptionalDependencyNotAvailable:
33+
from ...utils.dummy_torch_and_transformers_objects import *
34+
else:
35+
from .pipeline_pixart_alpha import PixArtAlphaPipeline
36+
37+
else:
38+
import sys
39+
40+
sys.modules[__name__] = _LazyModule(
41+
__name__,
42+
globals()["__file__"],
43+
_import_structure,
44+
module_spec=__spec__,
45+
)
46+
47+
for name, value in _dummy_objects.items():
48+
setattr(sys.modules[__name__], name, value)

0 commit comments

Comments (0)