Skip to content

Commit 58735eb

Browse files
committed
fix: introduce pydantic v1/v2 code to hanble v1 dataclasses
1 parent 04112ae commit 58735eb

File tree

1 file changed

+41
-21
lines changed

1 file changed

+41
-21
lines changed

polyfactory/factories/pydantic_factory.py

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,15 @@
6161
# is installed.
6262
from pydantic import PyObject
6363

64-
# prevent unbound variable warnings
64+
# Prevent unbound variable warnings
6565
BaseModelV2 = BaseModelV1
6666
UndefinedV2 = Undefined
67+
68+
if TYPE_CHECKING:
69+
from pydantic.dataclasses import Dataclass as PydanticDataclassV1 # pyright: ignore[reportPrivateImportUsage]
70+
71+
# Prevent unbound variable warnings
72+
PydanticDataclassV2 = PydanticDataclassV1
6773
except ImportError:
6874
# pydantic v2
6975

@@ -92,6 +98,8 @@
9298
from pydantic.v1.color import Color # type: ignore[assignment]
9399
from pydantic.v1.fields import DeferredType, ModelField, Undefined
94100

101+
if TYPE_CHECKING:
102+
from pydantic.dataclasses import PydanticDataclass as PydanticDataclassV2 # pyright: ignore[reportPrivateImportUsage]
95103

96104
if TYPE_CHECKING:
97105
from collections import abc
@@ -100,7 +108,6 @@
100108

101109
from typing_extensions import NotRequired, TypeGuard
102110

103-
from pydantic.dataclasses import PydanticDataclass # pyright: ignore[reportPrivateImportUsage]
104111

105112
ModelT = TypeVar("ModelT", bound="BaseModelV1 | BaseModelV2") # pyright: ignore[reportInvalidTypeForm]
106113
T = TypeVar("T")
@@ -635,8 +642,11 @@ def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]: # pyright: ign
635642
return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2)
636643

637644

638-
def is_pydantic_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclass]:
639-
# This method is available in the `pydantic.dataclasses` module for python >= 3.9
645+
def _is_pydantic_v1_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclassV1]:
646+
return is_dataclass(cls) and "__pydantic_model__" in cls.__dict__
647+
648+
649+
def _is_pydantic_v2_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclassV2]:
640650
return is_dataclass(cls) and "__pydantic_validator__" in cls.__dict__
641651

642652

@@ -647,27 +657,37 @@ class PydanticDataclassFactory(ModelFactory[T]): # type: ignore[type-var]
647657

648658
@classmethod
649659
def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
650-
return is_pydantic_dataclass(value)
660+
return _is_pydantic_v1_dataclass(value) or _is_pydantic_v2_dataclass(value)
651661

652662
@classmethod
653663
def get_model_fields(cls) -> list[FieldMeta]:
654-
if not is_pydantic_dataclass(cls.__model__):
664+
if _is_pydantic_v1_dataclass(cls.__model__):
665+
pydantic_model = cls.__model__.__pydantic_model__
666+
cls._fields_metadata = [
667+
PydanticFieldMeta.from_model_field(
668+
field,
669+
use_alias=not pydantic_model.__config__.allow_population_by_field_name, # type: ignore[attr-defined]
670+
random=cls.__random__,
671+
)
672+
for field in pydantic_model.__fields__.values()
673+
]
674+
elif _is_pydantic_v2_dataclass(cls.__model__):
675+
pydantic_fields = cls.__model__.__pydantic_fields__
676+
pydantic_config = cls.__model__.__pydantic_config__
677+
cls._fields_metadata = [
678+
PydanticFieldMeta.from_field_info(
679+
field_info=field_info,
680+
field_name=field_name,
681+
random=cls.__random__,
682+
use_alias=not pydantic_config.get(
683+
"populate_by_name",
684+
False,
685+
),
686+
)
687+
for field_name, field_info in pydantic_fields.items()
688+
]
689+
else:
655690
# This should be unreachable
656691
return []
657692

658-
pydantic_fields = cls.__model__.__pydantic_fields__
659-
pydantic_config = cls.__model__.__pydantic_config__
660-
cls._fields_metadata = [
661-
PydanticFieldMeta.from_field_info(
662-
field_info=field_info,
663-
field_name=field_name,
664-
random=cls.__random__,
665-
use_alias=not pydantic_config.get(
666-
"populate_by_name",
667-
False,
668-
),
669-
)
670-
for field_name, field_info in pydantic_fields.items()
671-
]
672-
673693
return cls._fields_metadata

0 commit comments

Comments
 (0)