From d689ce0aca962736a37979ae828588f0deed0580 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Wed, 27 Sep 2023 01:19:46 +0000 Subject: [PATCH 01/37] feat(type-coverage-gen): Initial implementation of type coverage generation --- .pre-commit-config.yaml | 2 +- polyfactory/factories/base.py | 171 +++++++++++++-- polyfactory/utils/helpers.py | 30 ++- polyfactory/utils/model_coverage.py | 114 ++++++++++ polyfactory/value_generators/complex_types.py | 66 +++++- tests/test_type_coverage_generation.py | 194 ++++++++++++++++++ 6 files changed, 539 insertions(+), 38 deletions(-) create mode 100644 polyfactory/utils/model_coverage.py create mode 100644 tests/test_type_coverage_generation.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 99cf36fa..71f6d84a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ default_language_version: - python: "3.11" + python: "3.10" repos: - repo: https://github.com/compilerla/conventional-pre-commit rev: v2.4.0 diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index e7c5efd5..87c0db6f 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -22,6 +22,7 @@ from os.path import realpath from pathlib import Path from random import Random +from types import NoneType from typing import ( TYPE_CHECKING, Any, @@ -46,22 +47,12 @@ MIN_COLLECTION_LENGTH, RANDOMIZE_COLLECTION_LENGTH, ) -from polyfactory.exceptions import ( - ConfigurationException, - MissingBuildKwargException, - ParameterException, -) +from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use -from polyfactory.utils.helpers import unwrap_annotation, unwrap_args, unwrap_optional -from polyfactory.utils.predicates import ( - get_type_origin, - is_any, - is_literal, - is_optional, - is_safe_subclass, - is_union, -) -from polyfactory.value_generators.complex_types import handle_collection_type +from polyfactory.utils.helpers import flatten_annotation, unwrap_annotation, unwrap_args, unwrap_optional +from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage +from polyfactory.utils.predicates import get_type_origin, is_any, is_literal, is_optional, is_safe_subclass, is_union +from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage from polyfactory.value_generators.constrained_collections import ( handle_constrained_collection, handle_constrained_mapping, @@ -76,11 +67,7 @@ from polyfactory.value_generators.constrained_strings import handle_constrained_string_or_bytes from polyfactory.value_generators.constrained_url import handle_constrained_url from polyfactory.value_generators.constrained_uuid import handle_constrained_uuid -from polyfactory.value_generators.primitives import ( - create_random_boolean, - create_random_bytes, - create_random_string, -) +from polyfactory.value_generators.primitives import create_random_boolean, create_random_bytes, create_random_string if TYPE_CHECKING: from typing_extensions import TypeGuard @@ -330,6 +317,32 @@ def _handle_factory_field(cls, field_value: Any, field_build_parameters: Any | N return field_value() if callable(field_value) else field_value + @classmethod + def _handle_factory_field_coverage(cls, field_value: Any, field_build_parameters: Any | None = None) -> Any: + """Handle a value defined on the factory class itself. + + :param field_value: A value defined as an attribute on the factory class. + :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + + :returns: An arbitrary value correlating with the given field_meta value. + """ + if is_safe_subclass(field_value, BaseFactory): + if isinstance(field_build_parameters, Mapping): + return CoverageContainer(field_value.coverage(**field_build_parameters)) + + if isinstance(field_build_parameters, Sequence): + return [CoverageContainer(field_value.coverage(**parameter)) for parameter in field_build_parameters] + + return CoverageContainer(field_value.coverage()) + + if isinstance(field_value, Use): + return field_value.to_value() + + if isinstance(field_value, Fixture): + return CoverageContainerCallable(field_value.to_value) + + return CoverageContainerCallable(field_value) if callable(field_value) else field_value + @classmethod def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]: """Get a factory from registered factories or generate a factory dynamically. @@ -692,6 +705,67 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912 msg, ) + @classmethod + def get_field_value_coverage( # noqa: C901 + cls, + field_meta: FieldMeta, + field_build_parameters: Any | None = None, + ) -> abc.Iterable[Any]: + """Return a field value on the subclass if existing, otherwise returns a mock value. + + :param field_meta: FieldMeta instance. + :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + + :returns: An arbitrary value. + + """ + if cls.is_ignored_type(field_meta.annotation): + return [None] + + for unwrapped_annotation in flatten_annotation(field_meta.annotation): + if unwrapped_annotation in (None, NoneType): + yield None + + elif is_literal(annotation=unwrapped_annotation) and (literal_args := get_args(unwrapped_annotation)): + yield CoverageContainer(literal_args) + + elif isinstance(unwrapped_annotation, EnumMeta): + yield CoverageContainer(list(unwrapped_annotation)) + + elif field_meta.constraints: + yield CoverageContainerCallable( + cls.get_constrained_field_value, + annotation=unwrapped_annotation, + field_meta=field_meta, + ) + + elif BaseFactory.is_factory_type(annotation=unwrapped_annotation): + yield CoverageContainer( + cls._get_or_create_factory(model=unwrapped_annotation).coverage( + **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}), + ), + ) + + elif (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection): + yield handle_collection_type_coverage(field_meta, origin, cls) + + elif is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar): + yield create_random_string(cls.__random__, min_length=1, max_length=10) + + elif provider := cls.get_provider_map().get(unwrapped_annotation): + yield CoverageContainerCallable(provider) + + elif callable(unwrapped_annotation): + # if value is a callable we can try to naively call it. + # this will work for callables that do not require any parameters passed + with suppress(Exception): + yield CoverageContainerCallable(unwrapped_annotation) + else: + msg = f"Unsupported type: {unwrapped_annotation!r}\n\nEither extend the providers map or add a factory function for this type." + raise ParameterException( + msg, + ) + @classmethod def should_set_none_value(cls, field_meta: FieldMeta) -> bool: """Determine whether a given model field_meta should be set to None. @@ -777,6 +851,50 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: return result + @classmethod + def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: + """Process the given kwargs and generate values for the factory's model. + + :param kwargs: Any build kwargs. + + :returns: A dictionary of build results. + + """ + result: dict[str, Any] = {**kwargs} + generate_post: dict[str, PostGenerated] = {} + + for field_meta in cls.get_model_fields(): + field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs) + + if cls.should_set_field_value(field_meta, **kwargs): + if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name): + field_value = getattr(cls, field_meta.name) + if isinstance(field_value, Ignore): + continue + + if isinstance(field_value, Require) and field_meta.name not in kwargs: + msg = f"Require kwarg {field_meta.name} is missing" + raise MissingBuildKwargException(msg) + + if isinstance(field_value, PostGenerated): + generate_post[field_meta.name] = field_value + continue + + result[field_meta.name] = cls._handle_factory_field_coverage( + field_value=field_value, + field_build_parameters=field_build_parameters, + ) + continue + + result[field_meta.name] = CoverageContainer( + cls.get_field_value_coverage(field_meta, field_build_parameters=field_build_parameters), + ) + + for resolved in resolve_kwargs_coverage(result): + for field_name, post_generator in generate_post.items(): + resolved[field_name] = post_generator.to_value(field_name, resolved) + yield resolved + @classmethod def build(cls, **kwargs: Any) -> T: """Build an instance of the factory's __model__ @@ -801,6 +919,19 @@ def batch(cls, size: int, **kwargs: Any) -> list[T]: """ return [cls.build(**kwargs) for _ in range(size)] + @classmethod + def coverage(cls, **kwargs: Any) -> abc.Iterator[T]: + """Build a batch of the factory's Meta.model will full coverage of the sub-types of the model. + + :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + + :returns: A iterator of instances of type T. + + """ + for data in cls.process_kwargs_coverage(**kwargs): + instance = cls.__model__(**data) + yield cast("T", instance) + @classmethod def create_sync(cls, **kwargs: Any) -> T: """Build and persists synchronously a single model instance. diff --git a/polyfactory/utils/helpers.py b/polyfactory/utils/helpers.py index ad0c3764..c15f283b 100644 --- a/polyfactory/utils/helpers.py +++ b/polyfactory/utils/helpers.py @@ -1,17 +1,13 @@ from __future__ import annotations import sys +from types import NoneType from typing import TYPE_CHECKING, Any from typing_extensions import get_args, get_origin from polyfactory.constants import TYPE_MAPPING -from polyfactory.utils.predicates import ( - is_annotated, - is_new_type, - is_optional, - is_union, -) +from polyfactory.utils.predicates import is_annotated, is_new_type, is_optional, is_union if TYPE_CHECKING: from random import Random @@ -76,6 +72,28 @@ def unwrap_annotation(annotation: Any, random: Random) -> Any: return annotation +def flatten_annotation(annotation: Any) -> list[Any]: + """Flattens an annotation. + :param annotation: A type annotation. + :returns: The flattened annotations. + """ + flat = [] + if is_new_type(annotation): + flat.extend(flatten_annotation(unwrap_new_type(annotation))) + elif is_optional(annotation): + flat.append(NoneType) + flat.extend(flatten_annotation(next(arg for arg in get_args(annotation) if arg not in (NoneType, None)))) + elif is_annotated(annotation): + flat.extend(flatten_annotation(get_args(annotation)[0])) + elif is_union(annotation): + for a in get_args(annotation): + flat.extend(flatten_annotation(a)) + else: + flat.append(annotation) + + return flat + + def unwrap_args(annotation: Any, random: Random) -> tuple[Any, ...]: """Unwrap the annotation and return any type args. diff --git a/polyfactory/utils/model_coverage.py b/polyfactory/utils/model_coverage.py new file mode 100644 index 00000000..bc4302f4 --- /dev/null +++ b/polyfactory/utils/model_coverage.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable, Iterator, Mapping, MutableSequence +from typing import AbstractSet, Any, Generic, ParamSpec, Set, TypeVar + + +class CoverageContainerBase(ABC): + @abstractmethod + def next_value(self) -> Any: + ... + + @abstractmethod + def is_done(self) -> bool: + ... + + +class CoverageContainer(CoverageContainerBase): + def __init__(self, instances: Iterable[Any]) -> None: + self._pos = 0 + self._instances = list(instances) + if len(self._instances) == 0: + msg = "CoverageContainer must have at least one instance" + raise ValueError(msg) + + def next_value(self) -> Any: + value = self._instances[self._pos % len(self._instances)] + if isinstance(value, CoverageContainerBase): + result = value.next_value() + if value.is_done(): + # Only move onto the next instance if the sub-container is done + self._pos += 1 + return result + + self._pos += 1 + return value + + def is_done(self) -> bool: + return self._pos >= len(self._instances) + + def __repr__(self) -> str: + return f"CoverageContainer(instances={self._instances}, is_done={self.is_done()})" + + +T = TypeVar("T") +P = ParamSpec("P") + + +class CoverageContainerCallable(CoverageContainerBase, Generic[T]): + def __init__(self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None: + self._func = func + self._args = args + self._kwargs = kwargs + + def next_value(self) -> T: + return self._func(*self._args, **self._kwargs) + + def is_done(self) -> bool: + return True + + +def _resolve_next(unresolved: Any) -> tuple[Any, bool]: # noqa: C901 + if isinstance(unresolved, CoverageContainerBase): + result, done = _resolve_next(unresolved.next_value()) + return result, unresolved.is_done() and done + + if isinstance(unresolved, Mapping): + result = {} + done_status = True + for key, value in unresolved.items(): + val_resolved, val_done = _resolve_next(value) + key_resolved, key_done = _resolve_next(key) + result[key_resolved] = val_resolved + done_status = done_status and val_done and key_done + return result, done_status + + if isinstance(unresolved, (tuple, MutableSequence)): + result = [] + done_status = True + for value in unresolved: + resolved, done = _resolve_next(value) + result.append(resolved) + done_status = done_status and done + if isinstance(unresolved, tuple): + result = tuple(result) + return result, done_status + + if isinstance(unresolved, Set): + result = type(unresolved)() + done_status = True + for value in unresolved: + resolved, done = _resolve_next(value) + result.add(resolved) + done_status = done_status and done + return result, done_status + + if issubclass(type(unresolved), AbstractSet): + result = type(unresolved)() + done_status = True + resolved_values = [] + for value in unresolved: + resolved, done = _resolve_next(value) + resolved_values.append(resolved) + done_status = done_status and done + return result.union(resolved_values), done_status + + return unresolved, True + + +def resolve_kwargs_coverage(kwargs: dict[str, Any]) -> Iterator[dict[str, Any]]: + done = False + while not done: + resolved, done = _resolve_next(kwargs) + yield resolved diff --git a/polyfactory/value_generators/complex_types.py b/polyfactory/value_generators/complex_types.py index 46d39df3..29def9f9 100644 --- a/polyfactory/value_generators/complex_types.py +++ b/polyfactory/value_generators/complex_types.py @@ -1,20 +1,11 @@ from __future__ import annotations -from typing import ( - TYPE_CHECKING, - AbstractSet, - Any, - Iterable, - MutableMapping, - MutableSequence, - Set, - Tuple, - cast, -) +from typing import TYPE_CHECKING, AbstractSet, Any, Iterable, MutableMapping, MutableSequence, Set, Tuple, cast from typing_extensions import is_typeddict from polyfactory.field_meta import FieldMeta +from polyfactory.utils.model_coverage import CoverageContainer if TYPE_CHECKING: from polyfactory.factories.base import BaseFactory @@ -60,3 +51,56 @@ def handle_collection_type(field_meta: FieldMeta, container_type: type, factory: msg = f"Unsupported container type: {container_type}" raise NotImplementedError(msg) + + +def handle_collection_type_coverage( + field_meta: FieldMeta, + container_type: type, + factory: type[BaseFactory[Any]], +) -> Any: + """Handle coverage generation of container types recursively. + + :param container_type: A type that can accept type arguments. + :param factory: A factory. + :param field_meta: A field meta instance. + + :returns: An unresolved built result. + """ + container = container_type() + if not field_meta.children: + return container + + if issubclass(container_type, MutableMapping) or is_typeddict(container_type): + for key_field_meta, value_field_meta in cast( + Iterable[Tuple[FieldMeta, FieldMeta]], + zip(field_meta.children[::2], field_meta.children[1::2]), + ): + key = CoverageContainer(factory.get_field_value_coverage(key_field_meta)) + value = CoverageContainer(factory.get_field_value_coverage(value_field_meta)) + container[key] = value + return container + + if issubclass(container_type, MutableSequence): + container_instance = container_type() + for subfield_meta in field_meta.children: + container_instance.extend(factory.get_field_value_coverage(subfield_meta)) + + return container_instance + + if issubclass(container_type, Set): + set_instance = container_type() + for subfield_meta in field_meta.children: + set_instance = set_instance.union(factory.get_field_value_coverage(subfield_meta)) + + return set_instance + + if issubclass(container_type, AbstractSet): + return container.union(handle_collection_type_coverage(field_meta, set, factory)) + + if issubclass(container_type, tuple): + return container_type( + CoverageContainer(factory.get_field_value_coverage(subfield_meta)) for subfield_meta in field_meta.children + ) + + msg = f"Unsupported container type: {container_type}" + raise NotImplementedError(msg) diff --git a/tests/test_type_coverage_generation.py b/tests/test_type_coverage_generation.py new file mode 100644 index 00000000..3930f8af --- /dev/null +++ b/tests/test_type_coverage_generation.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +from datetime import date +from typing import Literal, Sequence +from uuid import UUID + +import pytest +from pydantic import BaseModel +from typing_extensions import TypedDict + +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.factories.typed_dict_factory import TypedDictFactory + + +class Stringy(BaseModel): + string: str + + +class Numberish(BaseModel): + number: int | float + + +class Datelike(BaseModel): + birthday: date + + +class Profile(BaseModel): + name: Stringy + high_score: Numberish + dob: Datelike + data: Stringy | Datelike | Numberish + + +class Tuple(BaseModel): + tuple_: tuple[int | str, tuple[Numberish, int]] + + +class Collective(BaseModel): + set_: set[int | str] + list_: list[int | str] + frozenset_: frozenset[int | str] + sequence_: Sequence[int | str] + + +class Literally(BaseModel): + literal: Literal["a", "b", 1, 2] + + +class Thesaurus(BaseModel): + dict_simple: dict[str, int] + dict_more_key_types: dict[str | int | float, int | str] + dict_more_value_types: dict[str, int | str] + + +class TypedThesaurus(TypedDict): + number: int + string: str + union: int | str + collection: list[int | str] + + +class TypedThesaurusModel(BaseModel): + thesaurus: TypedThesaurus + + +def test_coverage_count() -> None: + class ProfileFactory(ModelFactory[Profile]): + __model__ = Profile + + results = list(ProfileFactory.coverage()) + + assert len(results) == 4 + + for result in results: + assert isinstance(result, Profile) + + +def test_coverage_tuple() -> None: + class TupleFactory(ModelFactory[Tuple]): + __model__ = Tuple + + results = list(TupleFactory.coverage()) + + assert len(results) == 2 + + a0, (b0, c0) = results[0].tuple_ + a1, (b1, c1) = results[1].tuple_ + + assert isinstance(a0, int) and isinstance(b0, Numberish) and isinstance(b0.number, int) and isinstance(c0, int) + assert isinstance(a1, str) and isinstance(b1, Numberish) and isinstance(b1.number, float) and isinstance(c1, int) + + +def test_coverage_collection() -> None: + class CollectiveFactory(ModelFactory[Collective]): + __model__ = Collective + + results = list(CollectiveFactory.coverage()) + + assert len(results) == 1 + + result = results[0] + + assert len(result.set_) == 2 + assert len(result.list_) == 2 + assert len(result.frozenset_) == 2 + assert len(result.sequence_) == 2 + + v0, v1 = result.set_ + assert {type(v0), type(v1)} == {int, str} + v0, v1 = result.list_ + assert {type(v0), type(v1)} == {int, str} + v0, v1 = result.frozenset_ + assert {type(v0), type(v1)} == {int, str} + v0, v1 = result.sequence_ + assert {type(v0), type(v1)} == {int, str} + + +def test_coverage_literal() -> None: + class LiterallyFactory(ModelFactory[Literally]): + __model__ = Literally + + results = list(LiterallyFactory.coverage()) + + assert len(results) == 4 + + assert results[0].literal == "a" + assert results[1].literal == "b" + assert results[2].literal == 1 + assert results[3].literal == 2 + + +def test_coverage_dict() -> None: + class ThesaurusFactory(ModelFactory[Thesaurus]): + __model__ = Thesaurus + + results = list(ThesaurusFactory.coverage()) + + assert len(results) == 3 + + +@pytest.mark.skip(reason="Does not support recursive types yet.") +def test_coverage_recursive() -> None: + class Recursive(BaseModel): + r: Recursive | None + + class RecursiveFactory(ModelFactory[Recursive]): + __model__ = Recursive + + results = list(RecursiveFactory.coverage()) + assert len(results) == 2 + + +def test_coverage_typed_dict() -> None: + class TypedThesaurusFactory(TypedDictFactory[TypedThesaurus]): + __model__ = TypedThesaurus + + results = list(TypedThesaurusFactory.coverage()) + + assert len(results) == 2 + + example = TypedThesaurusFactory.build() + for result in results: + assert result.keys() == example.keys() + + +def test_coverage_typed_dict_field() -> None: + class TypedThesaurusModelFactory(ModelFactory[TypedThesaurusModel]): + __model__ = TypedThesaurusModel + + class TypedThesaurusFactory(TypedDictFactory[TypedThesaurus]): + __model__ = TypedThesaurus + + results = list(TypedThesaurusModelFactory.coverage()) + + assert len(results) == 2 + + example = TypedThesaurusFactory.build() + + for result in results: + assert result.thesaurus.keys() == example.keys() + + +def test_coverage_values_unique() -> None: + class Unique(BaseModel): + uuid: UUID + data: int | str + + class UniqueFactory(ModelFactory[Unique]): + __model__ = Unique + + results = list(UniqueFactory.coverage()) + + assert len(results) == 2 + assert results[0].uuid != results[1].uuid From 84c9c51ed8f9076c10e4c77f05e49dcb25e0734c Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Wed, 27 Sep 2023 01:29:20 +0000 Subject: [PATCH 02/37] fix: revert change to .pre-commit-config.yaml --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 71f6d84a..99cf36fa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ default_language_version: - python: "3.10" + python: "3.11" repos: - repo: https://github.com/compilerla/conventional-pre-commit rev: v2.4.0 From e97b1c18f62cbdf29caa0043c39ce02f8822de61 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Wed, 27 Sep 2023 01:54:24 +0000 Subject: [PATCH 03/37] fix: Update NoneType importing for older python versions --- polyfactory/factories/base.py | 6 +++++- polyfactory/utils/helpers.py | 8 ++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 87c0db6f..dee637f7 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -22,7 +22,11 @@ from os.path import realpath from pathlib import Path from random import Random -from types import NoneType + +try: + from types import NoneType +except ImportError: + NoneType = type(None) # type: ignore[misc,assignment] from typing import ( TYPE_CHECKING, Any, diff --git a/polyfactory/utils/helpers.py b/polyfactory/utils/helpers.py index c15f283b..66fd55cb 100644 --- a/polyfactory/utils/helpers.py +++ b/polyfactory/utils/helpers.py @@ -1,9 +1,13 @@ from __future__ import annotations import sys -from types import NoneType from typing import TYPE_CHECKING, Any +try: + from types import NoneType +except ImportError: + NoneType = type(None) # type: ignore[misc,assignment] + from typing_extensions import get_args, get_origin from polyfactory.constants import TYPE_MAPPING @@ -47,7 +51,7 @@ def unwrap_optional(annotation: Any) -> Any: :returns: A type annotation """ while is_optional(annotation): - annotation = next(arg for arg in get_args(annotation) if arg not in (type(None), None)) + annotation = next(arg for arg in get_args(annotation) if arg not in (NoneType, None)) return annotation From 8df692fbcf41e888309f28560f223e5ac460c65a Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Wed, 27 Sep 2023 02:11:10 +0000 Subject: [PATCH 04/37] fix: apply sourcery refactor --- polyfactory/utils/model_coverage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/polyfactory/utils/model_coverage.py b/polyfactory/utils/model_coverage.py index bc4302f4..023d5d61 100644 --- a/polyfactory/utils/model_coverage.py +++ b/polyfactory/utils/model_coverage.py @@ -19,7 +19,7 @@ class CoverageContainer(CoverageContainerBase): def __init__(self, instances: Iterable[Any]) -> None: self._pos = 0 self._instances = list(instances) - if len(self._instances) == 0: + if not self._instances: msg = "CoverageContainer must have at least one instance" raise ValueError(msg) From 973417238749966e39670a5f993a061e8143c540 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Wed, 27 Sep 2023 03:43:38 +0000 Subject: [PATCH 05/37] fix: import ParamSpec from typing_extensions --- polyfactory/utils/model_coverage.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/polyfactory/utils/model_coverage.py b/polyfactory/utils/model_coverage.py index 023d5d61..e73fd1b2 100644 --- a/polyfactory/utils/model_coverage.py +++ b/polyfactory/utils/model_coverage.py @@ -2,7 +2,9 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Mapping, MutableSequence -from typing import AbstractSet, Any, Generic, ParamSpec, Set, TypeVar +from typing import AbstractSet, Any, Generic, Set, TypeVar + +from typing_extensions import ParamSpec class CoverageContainerBase(ABC): From 373fea46689aec2a056029ed9aa3e11eeead2975 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Wed, 27 Sep 2023 04:01:49 +0000 Subject: [PATCH 06/37] fix: Skip tests on py versions < 3.10 --- .pre-commit-config.yaml | 2 +- tests/test_type_coverage_generation.py | 104 +++++++++++++------------ 2 files changed, 57 insertions(+), 49 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 99cf36fa..71f6d84a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ default_language_version: - python: "3.11" + python: "3.10" repos: - repo: https://github.com/compilerla/conventional-pre-commit rev: v2.4.0 diff --git a/tests/test_type_coverage_generation.py b/tests/test_type_coverage_generation.py index 3930f8af..9d7a1932 100644 --- a/tests/test_type_coverage_generation.py +++ b/tests/test_type_coverage_generation.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from datetime import date from typing import Literal, Sequence from uuid import UUID @@ -12,58 +13,23 @@ from polyfactory.factories.typed_dict_factory import TypedDictFactory -class Stringy(BaseModel): - string: str - - -class Numberish(BaseModel): - number: int | float - - -class Datelike(BaseModel): - birthday: date - - -class Profile(BaseModel): - name: Stringy - high_score: Numberish - dob: Datelike - data: Stringy | Datelike | Numberish - - -class Tuple(BaseModel): - tuple_: tuple[int | str, tuple[Numberish, int]] - - -class Collective(BaseModel): - set_: set[int | str] - list_: list[int | str] - frozenset_: frozenset[int | str] - sequence_: Sequence[int | str] - - -class Literally(BaseModel): - literal: Literal["a", "b", 1, 2] - - -class Thesaurus(BaseModel): - dict_simple: dict[str, int] - dict_more_key_types: dict[str | int | float, int | str] - dict_more_value_types: dict[str, int | str] - - -class TypedThesaurus(TypedDict): - number: int - string: str - union: int | str - collection: list[int | str] +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") +def test_coverage_count() -> None: + class Stringy(BaseModel): + string: str + class Numberish(BaseModel): + number: int | float -class TypedThesaurusModel(BaseModel): - thesaurus: TypedThesaurus + class Datelike(BaseModel): + birthday: date + class Profile(BaseModel): + name: Stringy + high_score: Numberish + dob: Datelike + data: Stringy | Datelike | Numberish -def test_coverage_count() -> None: class ProfileFactory(ModelFactory[Profile]): __model__ = Profile @@ -75,7 +41,14 @@ class ProfileFactory(ModelFactory[Profile]): assert isinstance(result, Profile) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_tuple() -> None: + class Numberish(BaseModel): + number: int | float + + class Tuple(BaseModel): + tuple_: tuple[int | str, tuple[Numberish, int]] + class TupleFactory(ModelFactory[Tuple]): __model__ = Tuple @@ -90,7 +63,14 @@ class TupleFactory(ModelFactory[Tuple]): assert isinstance(a1, str) and isinstance(b1, Numberish) and isinstance(b1.number, float) and isinstance(c1, int) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_collection() -> None: + class Collective(BaseModel): + set_: set[int | str] + list_: list[int | str] + frozenset_: frozenset[int | str] + sequence_: Sequence[int | str] + class CollectiveFactory(ModelFactory[Collective]): __model__ = Collective @@ -115,7 +95,11 @@ class CollectiveFactory(ModelFactory[Collective]): assert {type(v0), type(v1)} == {int, str} +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_literal() -> None: + class Literally(BaseModel): + literal: Literal["a", "b", 1, 2] + class LiterallyFactory(ModelFactory[Literally]): __model__ = Literally @@ -130,6 +114,11 @@ class LiterallyFactory(ModelFactory[Literally]): def test_coverage_dict() -> None: + class Thesaurus(BaseModel): + dict_simple: dict[str, int] + dict_more_key_types: dict[str | int | float, int | str] + dict_more_value_types: dict[str, int | str] + class ThesaurusFactory(ModelFactory[Thesaurus]): __model__ = Thesaurus @@ -139,6 +128,7 @@ class ThesaurusFactory(ModelFactory[Thesaurus]): @pytest.mark.skip(reason="Does not support recursive types yet.") +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_recursive() -> None: class Recursive(BaseModel): r: Recursive | None @@ -150,7 +140,14 @@ class RecursiveFactory(ModelFactory[Recursive]): assert len(results) == 2 +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_typed_dict() -> None: + class TypedThesaurus(TypedDict): + number: int + string: str + union: int | str + collection: list[int | str] + class TypedThesaurusFactory(TypedDictFactory[TypedThesaurus]): __model__ = TypedThesaurus @@ -163,7 +160,17 @@ class TypedThesaurusFactory(TypedDictFactory[TypedThesaurus]): assert result.keys() == example.keys() +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_typed_dict_field() -> None: + class TypedThesaurus(TypedDict): + number: int + string: str + union: int | str + collection: list[int | str] + + class TypedThesaurusModel(BaseModel): + thesaurus: TypedThesaurus + class TypedThesaurusModelFactory(ModelFactory[TypedThesaurusModel]): __model__ = TypedThesaurusModel @@ -180,6 +187,7 @@ class TypedThesaurusFactory(TypedDictFactory[TypedThesaurus]): assert result.thesaurus.keys() == example.keys() +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_values_unique() -> None: class Unique(BaseModel): uuid: UUID From ae38e54ef3ce6ea92662b0098c55fe61a3ba06f6 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Wed, 27 Sep 2023 04:06:10 +0000 Subject: [PATCH 07/37] fix: revert changes to .pre-commit-config.yaml --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 71f6d84a..99cf36fa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ default_language_version: - python: "3.10" + python: "3.11" repos: - repo: https://github.com/compilerla/conventional-pre-commit rev: v2.4.0 From 1fb060823a9f8d18d5c3b5d102f99c36ef709328 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Wed, 27 Sep 2023 14:11:07 +1000 Subject: [PATCH 08/37] chore: Create devcontainer.json --- .devcontainer/devcontainer.json | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .devcontainer/devcontainer.json diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 00000000..2c9de588 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,6 @@ +{ + "image": "mcr.microsoft.com/devcontainers/universal:2", + "features": { + "ghcr.io/devcontainers/features/python:1": {"version":"3.11"} + } +} From f2289d0bc91930cb637fac37c9e2eae8fe46d3a4 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Wed, 27 Sep 2023 06:43:31 +0000 Subject: [PATCH 09/37] fix: remove .devcontainer dir --- .devcontainer/devcontainer.json | 6 ------ 1 file changed, 6 deletions(-) delete mode 100644 .devcontainer/devcontainer.json diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json deleted file mode 100644 index 2c9de588..00000000 --- a/.devcontainer/devcontainer.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "image": "mcr.microsoft.com/devcontainers/universal:2", - "features": { - "ghcr.io/devcontainers/features/python:1": {"version":"3.11"} - } -} From d9adc279a77f8bc037004ca9601d2311a9fcf0bf Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Wed, 27 Sep 2023 06:46:19 +0000 Subject: [PATCH 10/37] fix: Add missing test skip for older python versions --- tests/test_type_coverage_generation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_type_coverage_generation.py b/tests/test_type_coverage_generation.py index 9d7a1932..eec2518d 100644 --- a/tests/test_type_coverage_generation.py +++ b/tests/test_type_coverage_generation.py @@ -113,6 +113,7 @@ class LiterallyFactory(ModelFactory[Literally]): assert results[3].literal == 2 +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_dict() -> None: class Thesaurus(BaseModel): dict_simple: dict[str, int] From 20f481379480df9ba0f1cf69af64159ae8fdf4ba Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Wed, 27 Sep 2023 06:57:26 +0000 Subject: [PATCH 11/37] test: Add test for post generated in coverage generation --- tests/test_type_coverage_generation.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/test_type_coverage_generation.py b/tests/test_type_coverage_generation.py index eec2518d..076b4b3f 100644 --- a/tests/test_type_coverage_generation.py +++ b/tests/test_type_coverage_generation.py @@ -9,6 +9,7 @@ from pydantic import BaseModel from typing_extensions import TypedDict +from polyfactory.decorators import post_generated from polyfactory.factories.pydantic_factory import ModelFactory from polyfactory.factories.typed_dict_factory import TypedDictFactory @@ -201,3 +202,23 @@ class UniqueFactory(ModelFactory[Unique]): assert len(results) == 2 assert results[0].uuid != results[1].uuid + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") +def test_coverage_post_generated() -> None: + class Model(BaseModel): + i: int + j: int + + class Factory(ModelFactory[Model]): + __model__ = Model + + @post_generated + @classmethod + def i(cls, j: int) -> int: + return j + 10 + + results = list(Factory.coverage()) + assert len(results) == 1 + + assert results[0].i == results[0].j + 10 From 7f713397ecf0a540335787012111874608f1b31a Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Wed, 4 Oct 2023 03:05:12 +0000 Subject: [PATCH 12/37] test: Simplify type coverage generation tests --- tests/test_type_coverage_generation.py | 93 ++++++++++---------------- 1 file changed, 36 insertions(+), 57 deletions(-) diff --git a/tests/test_type_coverage_generation.py b/tests/test_type_coverage_generation.py index 076b4b3f..1f89369c 100644 --- a/tests/test_type_coverage_generation.py +++ b/tests/test_type_coverage_generation.py @@ -1,37 +1,27 @@ from __future__ import annotations -import sys +from dataclasses import dataclass from datetime import date from typing import Literal, Sequence from uuid import UUID import pytest -from pydantic import BaseModel from typing_extensions import TypedDict from polyfactory.decorators import post_generated -from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.factories.dataclass_factory import DataclassFactory from polyfactory.factories.typed_dict_factory import TypedDictFactory -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_count() -> None: - class Stringy(BaseModel): - string: str - - class Numberish(BaseModel): - number: int | float - - class Datelike(BaseModel): - birthday: date - - class Profile(BaseModel): - name: Stringy - high_score: Numberish - dob: Datelike - data: Stringy | Datelike | Numberish - - class ProfileFactory(ModelFactory[Profile]): + @dataclass + class Profile: + name: str + high_score: int | float + dob: date + data: str | date | int | float + + class ProfileFactory(DataclassFactory[Profile]): __model__ = Profile results = list(ProfileFactory.coverage()) @@ -42,15 +32,12 @@ class ProfileFactory(ModelFactory[Profile]): assert isinstance(result, Profile) -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_tuple() -> None: - class Numberish(BaseModel): - number: int | float + @dataclass + class Tuple: + tuple_: tuple[int | str, tuple[int | float, int]] - class Tuple(BaseModel): - tuple_: tuple[int | str, tuple[Numberish, int]] - - class TupleFactory(ModelFactory[Tuple]): + class TupleFactory(DataclassFactory[Tuple]): __model__ = Tuple results = list(TupleFactory.coverage()) @@ -60,19 +47,19 @@ class TupleFactory(ModelFactory[Tuple]): a0, (b0, c0) = results[0].tuple_ a1, (b1, c1) = results[1].tuple_ - assert isinstance(a0, int) and isinstance(b0, Numberish) and isinstance(b0.number, int) and isinstance(c0, int) - assert isinstance(a1, str) and isinstance(b1, Numberish) and isinstance(b1.number, float) and isinstance(c1, int) + assert isinstance(a0, int) and isinstance(b0, int) and isinstance(c0, int) + assert isinstance(a1, str) and isinstance(b1, float) and isinstance(c1, int) -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_collection() -> None: - class Collective(BaseModel): + @dataclass + class Collective: set_: set[int | str] list_: list[int | str] frozenset_: frozenset[int | str] sequence_: Sequence[int | str] - class CollectiveFactory(ModelFactory[Collective]): + class CollectiveFactory(DataclassFactory[Collective]): __model__ = Collective results = list(CollectiveFactory.coverage()) @@ -96,12 +83,12 @@ class CollectiveFactory(ModelFactory[Collective]): assert {type(v0), type(v1)} == {int, str} -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_literal() -> None: - class Literally(BaseModel): + @dataclass + class Literally: literal: Literal["a", "b", 1, 2] - class LiterallyFactory(ModelFactory[Literally]): + class LiterallyFactory(DataclassFactory[Literally]): __model__ = Literally results = list(LiterallyFactory.coverage()) @@ -114,14 +101,14 @@ class LiterallyFactory(ModelFactory[Literally]): assert results[3].literal == 2 -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_dict() -> None: - class Thesaurus(BaseModel): + @dataclass + class Thesaurus: dict_simple: dict[str, int] dict_more_key_types: dict[str | int | float, int | str] dict_more_value_types: dict[str, int | str] - class ThesaurusFactory(ModelFactory[Thesaurus]): + class ThesaurusFactory(DataclassFactory[Thesaurus]): __model__ = Thesaurus results = list(ThesaurusFactory.coverage()) @@ -130,19 +117,18 @@ class ThesaurusFactory(ModelFactory[Thesaurus]): @pytest.mark.skip(reason="Does not support recursive types yet.") -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_recursive() -> None: - class Recursive(BaseModel): + @dataclass + class Recursive: r: Recursive | None - class RecursiveFactory(ModelFactory[Recursive]): + class RecursiveFactory(DataclassFactory[Recursive]): __model__ = Recursive results = list(RecursiveFactory.coverage()) assert len(results) == 2 -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_typed_dict() -> None: class TypedThesaurus(TypedDict): number: int @@ -162,7 +148,6 @@ class TypedThesaurusFactory(TypedDictFactory[TypedThesaurus]): assert result.keys() == example.keys() -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_typed_dict_field() -> None: class TypedThesaurus(TypedDict): number: int @@ -170,32 +155,26 @@ class TypedThesaurus(TypedDict): union: int | str collection: list[int | str] - class TypedThesaurusModel(BaseModel): - thesaurus: TypedThesaurus - - class TypedThesaurusModelFactory(ModelFactory[TypedThesaurusModel]): - __model__ = TypedThesaurusModel - class TypedThesaurusFactory(TypedDictFactory[TypedThesaurus]): __model__ = TypedThesaurus - results = list(TypedThesaurusModelFactory.coverage()) + results = list(TypedThesaurusFactory.coverage()) assert len(results) == 2 example = TypedThesaurusFactory.build() for result in results: - assert result.thesaurus.keys() == example.keys() + assert result.keys() == example.keys() -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_values_unique() -> None: - class Unique(BaseModel): + @dataclass + class Unique: uuid: UUID data: int | str - class UniqueFactory(ModelFactory[Unique]): + class UniqueFactory(DataclassFactory[Unique]): __model__ = Unique results = list(UniqueFactory.coverage()) @@ -204,13 +183,13 @@ class UniqueFactory(ModelFactory[Unique]): assert results[0].uuid != results[1].uuid -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_post_generated() -> None: - class Model(BaseModel): + @dataclass + class Model: i: int j: int - class Factory(ModelFactory[Model]): + class Factory(DataclassFactory[Model]): __model__ = Model @post_generated From f93a4988167b95b03963a9ecbe58b3c312893fa0 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Wed, 4 Oct 2023 03:12:20 +0000 Subject: [PATCH 13/37] test: Add back min python3.10 version condition --- tests/test_type_coverage_generation.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_type_coverage_generation.py b/tests/test_type_coverage_generation.py index 1f89369c..cc3a6380 100644 --- a/tests/test_type_coverage_generation.py +++ b/tests/test_type_coverage_generation.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from dataclasses import dataclass from datetime import date from typing import Literal, Sequence @@ -13,6 +14,7 @@ from polyfactory.factories.typed_dict_factory import TypedDictFactory +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_count() -> None: @dataclass class Profile: @@ -32,6 +34,7 @@ class ProfileFactory(DataclassFactory[Profile]): assert isinstance(result, Profile) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_tuple() -> None: @dataclass class Tuple: @@ -51,6 +54,7 @@ class TupleFactory(DataclassFactory[Tuple]): assert isinstance(a1, str) and isinstance(b1, float) and isinstance(c1, int) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_collection() -> None: @dataclass class Collective: @@ -83,6 +87,7 @@ class CollectiveFactory(DataclassFactory[Collective]): assert {type(v0), type(v1)} == {int, str} +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_literal() -> None: @dataclass class Literally: @@ -101,6 +106,7 @@ class LiterallyFactory(DataclassFactory[Literally]): assert results[3].literal == 2 +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_dict() -> None: @dataclass class Thesaurus: @@ -117,6 +123,7 @@ class ThesaurusFactory(DataclassFactory[Thesaurus]): @pytest.mark.skip(reason="Does not support recursive types yet.") +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_recursive() -> None: @dataclass class Recursive: @@ -129,6 +136,7 @@ class RecursiveFactory(DataclassFactory[Recursive]): assert len(results) == 2 +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_typed_dict() -> None: class TypedThesaurus(TypedDict): number: int @@ -148,6 +156,7 @@ class TypedThesaurusFactory(TypedDictFactory[TypedThesaurus]): assert result.keys() == example.keys() +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_typed_dict_field() -> None: class TypedThesaurus(TypedDict): number: int @@ -168,6 +177,7 @@ class TypedThesaurusFactory(TypedDictFactory[TypedThesaurus]): assert result.keys() == example.keys() +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_values_unique() -> None: @dataclass class Unique: @@ -183,6 +193,7 @@ class UniqueFactory(DataclassFactory[Unique]): assert results[0].uuid != results[1].uuid +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_post_generated() -> None: @dataclass class Model: From e306abc6c6c1c1814679e08ff15ac88bdc8044a1 Mon Sep 17 00:00:00 2001 From: Jacob Coffee Date: Sat, 7 Oct 2023 12:41:17 -0500 Subject: [PATCH 14/37] fix(infra): update makefile (#399) --- Makefile | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index 826eedfd..c23d5282 100644 --- a/Makefile +++ b/Makefile @@ -5,11 +5,11 @@ SHELL := /bin/bash .DEFAULT_GOAL:=help .ONESHELL: -USING_PDM = $(shell grep "tool.pdm" pyproject.toml && echo "yes") -ENV_PREFIX = $(shell python3 -c "if __import__('pathlib').Path('.venv/bin/pip').exists(): print('.venv/bin/')") -VENV_EXISTS = $(shell python3 -c "if __import__('pathlib').Path('.venv/bin/activate').exists(): print('yes')") -PDM_OPTS ?= -PDM ?= pdm $(PDM_OPTS) +USING_PDM = $(shell grep "tool.pdm" pyproject.toml && echo "yes") +ENV_PREFIX := $(shell if [ -d .venv ]; then echo ".venv/bin/"; fi) +VENV_EXISTS := $(shell if [ -d .venv ]; then echo "yes"; fi) +PDM_OPTS ?= +PDM ?= pdm $(PDM_OPTS) .EXPORT_ALL_VARIABLES: From 8752b817692d20f385635864830e6fb3c336acca Mon Sep 17 00:00:00 2001 From: Andrew Truong <40660973+adhtruong@users.noreply.github.com> Date: Wed, 11 Oct 2023 21:15:03 +0100 Subject: [PATCH 15/37] docs: Install all dependencies for docs build (#404) --- .github/workflows/docs.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 702ae687..9992ed7a 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -27,7 +27,7 @@ jobs: cache: true - name: Install dependencies - run: pdm install -G:docs + run: pdm install -G:all - name: Fetch gh pages run: git fetch origin gh-pages --depth=1 From f84b77127ccc770e5de0db8d4a8041ad294964ee Mon Sep 17 00:00:00 2001 From: guacs <126393040+guacs@users.noreply.github.com> Date: Sun, 15 Oct 2023 00:06:47 +0530 Subject: [PATCH 16/37] fix: decouple the handling of collection length configuration from `FieldMeta` (#407) --- polyfactory/factories/base.py | 62 ++++++---- polyfactory/field_meta.py | 6 +- polyfactory/utils/helpers.py | 33 +++++- tests/test_attrs_factory.py | 57 +-------- tests/test_beanie_factory.py | 15 --- tests/test_collection_length.py | 108 ++++++++++++++++++ tests/test_dataclass_factory.py | 89 +-------------- tests/test_msgspec_factory.py | 52 --------- tests/test_odmantic_factory.py | 58 +--------- ...t_passing_build_args_to_child_factories.py | 67 +---------- tests/test_typeddict_factory.py | 18 --- 11 files changed, 182 insertions(+), 383 deletions(-) create mode 100644 tests/test_collection_length.py diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index e9a308a0..62d224f5 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -57,6 +57,16 @@ from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage from polyfactory.utils.predicates import get_type_origin, is_any, is_literal, is_optional, is_safe_subclass, is_union from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage +from polyfactory.utils.helpers import get_collection_type, unwrap_annotation, unwrap_args, unwrap_optional +from polyfactory.utils.predicates import ( + get_type_origin, + is_any, + is_literal, + is_optional, + is_safe_subclass, + is_union, +) +from polyfactory.value_generators.complex_types import handle_collection_type from polyfactory.value_generators.constrained_collections import ( handle_constrained_collection, handle_constrained_mapping, @@ -459,7 +469,7 @@ def create_factory( ) @classmethod - def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> Any: # noqa: C901, PLR0911, PLR0912 + def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> Any: # noqa: C901, PLR0911 try: constraints = cast("Constraints", field_meta.constraints) if is_safe_subclass(annotation, float): @@ -508,21 +518,15 @@ def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> pattern=constraints.get("pattern"), ) - if ( - is_safe_subclass(annotation, set) - or is_safe_subclass(annotation, list) - or is_safe_subclass(annotation, frozenset) - or is_safe_subclass(annotation, tuple) - ): - collection_type: type[list | set | tuple | frozenset] - if is_safe_subclass(annotation, list): - collection_type = list - elif is_safe_subclass(annotation, set): - collection_type = set - elif is_safe_subclass(annotation, tuple): - collection_type = tuple - else: - collection_type = frozenset + with suppress(ValueError): + collection_type = get_collection_type(annotation) + if collection_type == dict: + return handle_constrained_mapping( + factory=cls, + field_meta=field_meta, + min_items=constraints.get("min_length"), + max_items=constraints.get("max_length"), + ) return handle_constrained_collection( collection_type=collection_type, # type: ignore[type-var] @@ -534,14 +538,6 @@ def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> unique_items=constraints.get("unique_items", False), ) - if is_safe_subclass(annotation, dict): - return handle_constrained_mapping( - factory=cls, - field_meta=field_meta, - min_items=constraints.get("min_length"), - max_items=constraints.get("max_length"), - ) - if is_safe_subclass(annotation, date): return handle_constrained_date( faker=cls.__faker__, @@ -612,6 +608,24 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912 return factory.batch(size=batch_size) if (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection): + if cls.__randomize_collection_length__: + collection_type = get_collection_type(unwrapped_annotation) + if collection_type != dict: + return handle_constrained_collection( + collection_type=collection_type, # type: ignore[type-var] + factory=cls, + item_type=Any, + field_meta=field_meta.children[0] if field_meta.children else field_meta, + min_items=cls.__min_collection_length__, + max_items=cls.__max_collection_length__, + ) + return handle_constrained_mapping( + factory=cls, + field_meta=field_meta, + min_items=cls.__min_collection_length__, + max_items=cls.__max_collection_length__, + ) + return handle_collection_type(field_meta, origin, cls) if is_union(field_meta.annotation) and field_meta.children: diff --git a/polyfactory/field_meta.py b/polyfactory/field_meta.py index c0c4d9e8..a0d8bc91 100644 --- a/polyfactory/field_meta.py +++ b/polyfactory/field_meta.py @@ -150,11 +150,7 @@ def from_type( ) if field.type_args and not field.children: - if randomize_collection_length: - number_of_args = random.randint(min_collection_length, max_collection_length) - else: - number_of_args = 1 - + number_of_args = 1 extended_type_args = CollectionExtender.extend_type_args(field.annotation, field.type_args, number_of_args) field.children = [ FieldMeta.from_type( diff --git a/polyfactory/utils/helpers.py b/polyfactory/utils/helpers.py index 66fd55cb..db174354 100644 --- a/polyfactory/utils/helpers.py +++ b/polyfactory/utils/helpers.py @@ -1,7 +1,7 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Mapping try: from types import NoneType @@ -11,7 +11,13 @@ from typing_extensions import get_args, get_origin from polyfactory.constants import TYPE_MAPPING -from polyfactory.utils.predicates import is_annotated, is_new_type, is_optional, is_union +from polyfactory.utils.predicates import ( + is_annotated, + is_new_type, + is_optional, + is_safe_subclass, + is_union, +) if TYPE_CHECKING: from random import Random @@ -152,3 +158,26 @@ def normalize_annotation(annotation: Any, random: Random) -> Any: return origin[args] if origin is not type else annotation return origin + + +def get_collection_type(annotation: Any) -> type[list | tuple | set | frozenset | dict]: + """Get the collection type from the annotation. + + :param annotation: A type annotation. + + :returns: The collection type. + """ + + if is_safe_subclass(annotation, list): + return list + if is_safe_subclass(annotation, Mapping): + return dict + if is_safe_subclass(annotation, tuple): + return tuple + if is_safe_subclass(annotation, set): + return set + if is_safe_subclass(annotation, frozenset): + return frozenset + + msg = f"Unknown collection type - {annotation}" + raise ValueError(msg) diff --git a/tests/test_attrs_factory.py b/tests/test_attrs_factory.py index 270ae3e5..f4f56576 100644 --- a/tests/test_attrs_factory.py +++ b/tests/test_attrs_factory.py @@ -1,7 +1,7 @@ import datetime as dt from decimal import Decimal from enum import Enum -from typing import Any, Dict, FrozenSet, Generic, List, Set, Tuple, TypeVar +from typing import Any, Dict, Generic, List, Tuple, TypeVar from uuid import UUID import attrs @@ -135,61 +135,6 @@ class FooFactory(AttrsFactory[Foo]): assert foo == Foo(foo.aliased) -@pytest.mark.parametrize("type_", (Set, FrozenSet, List)) -def test_variable_length(type_: Any) -> None: - @define - class Foo: - items: type_[int] - - class FooFactory(AttrsFactory[Foo]): - __model__ = Foo - - __randomize_collection_length__ = True - number_of_args = 3 - - __min_collection_length__ = number_of_args - __max_collection_length__ = number_of_args - - foo = FooFactory.build() - assert len(foo.items) == 3 - - -def test_variable_length__dict() -> None: - @define - class Foo: - items: Dict[int, float] - - class FooFactory(AttrsFactory[Foo]): - __model__ = Foo - - __randomize_collection_length__ = True - number_of_args = 3 - - __min_collection_length__ = number_of_args - __max_collection_length__ = number_of_args - - foo = FooFactory.build() - assert len(foo.items) == 3 - - -def test_variable_length__tuple() -> None: - @define - class Foo: - items: Tuple[int, ...] - - class FooFactory(AttrsFactory[Foo]): - __model__ = Foo - - __randomize_collection_length__ = True - number_of_args = 3 - - __min_collection_length__ = number_of_args - __max_collection_length__ = number_of_args - - foo = FooFactory.build() - assert len(foo.items) == 3 - - def test_with_generics() -> None: T = TypeVar("T") diff --git a/tests/test_beanie_factory.py b/tests/test_beanie_factory.py index 6582bc9c..280ecb63 100644 --- a/tests/test_beanie_factory.py +++ b/tests/test_beanie_factory.py @@ -80,18 +80,3 @@ async def test_beanie_persistence_of_multiple_instances(beanie_init: Callable) - async def test_beanie_links(beanie_init: Callable) -> None: result = await MyOtherFactory.create_async() assert isinstance(result.document, MyDocument) - - -def test_variable_length(beanie_init: Callable) -> None: - class MyVariableFactory(BeanieDocumentFactory): - __model__ = MyDocument - - __randomize_collection_length__ = True - number_of_args = 3 - - __min_collection_length__ = number_of_args - __max_collection_length__ = number_of_args - - result = MyVariableFactory.build() - - assert len(result.siblings) == 3 diff --git a/tests/test_collection_length.py b/tests/test_collection_length.py new file mode 100644 index 00000000..b8563165 --- /dev/null +++ b/tests/test_collection_length.py @@ -0,0 +1,108 @@ +from typing import Dict, List, Optional, Set, Tuple + +import pytest +from pydantic.dataclasses import dataclass + +from polyfactory.factories import DataclassFactory + +MIN_MAX_PARAMETERS = ((10, 15), (20, 25), (30, 40), (40, 50)) + + +@pytest.mark.parametrize("type_", (List, Set)) +@pytest.mark.parametrize(("min_val", "max_val"), MIN_MAX_PARAMETERS) +def test_collection_length_with_list(min_val: int, max_val: int, type_: type) -> None: + @dataclass + class Foo: + foo: type_[int] # type: ignore + + class FooFactory(DataclassFactory[Foo]): + __model__ = Foo + __randomize_collection_length__ = True + __min_collection_length__ = min_val + __max_collection_length__ = max_val + + foo = FooFactory.build() + + assert len(foo.foo) >= min_val, len(foo.foo) + assert len(foo.foo) <= max_val, len(foo.foo) + + +@pytest.mark.parametrize(("min_val", "max_val"), MIN_MAX_PARAMETERS) +def test_collection_length_with_tuple(min_val: int, max_val: int) -> None: + @dataclass + class Foo: + foo: Tuple[int, ...] + + class FooFactory(DataclassFactory[Foo]): + __model__ = Foo + __randomize_collection_length__ = True + __min_collection_length__ = min_val + __max_collection_length__ = max_val + + foo = FooFactory.build() + + assert len(foo.foo) >= min_val, len(foo.foo) + assert len(foo.foo) <= max_val, len(foo.foo) + + +@pytest.mark.parametrize(("min_val", "max_val"), MIN_MAX_PARAMETERS) +def test_collection_length_with_dict(min_val: int, max_val: int) -> None: + @dataclass + class Foo: + foo: Dict[int, int] + + class FooFactory(DataclassFactory[Foo]): + __model__ = Foo + + __randomize_collection_length__ = True + __min_collection_length__ = min_val + __max_collection_length__ = max_val + + foo = FooFactory.build() + + assert len(foo.foo) >= min_val, len(foo.foo) + assert len(foo.foo) <= max_val, len(foo.foo) + + +@pytest.mark.parametrize(("min_val", "max_val"), MIN_MAX_PARAMETERS) +def test_collection_length_with_optional_not_allowed(min_val: int, max_val: int) -> None: + @dataclass + class Foo: + foo: Optional[List[int]] + + class FooFactory(DataclassFactory[Foo]): + __model__ = Foo + + __allow_none_optionals__ = False + __randomize_collection_length__ = True + __min_collection_length__ = min_val + __max_collection_length__ = max_val + + foo = FooFactory.build() + + assert foo.foo is not None + assert len(foo.foo) >= min_val, len(foo.foo) + assert len(foo.foo) <= max_val, len(foo.foo) + + +@pytest.mark.parametrize(("min_val", "max_val"), MIN_MAX_PARAMETERS) +def test_collection_length_with_optional_allowed(min_val: int, max_val: int) -> None: + @dataclass + class Foo: + foo: Optional[List[int]] + + class FooFactory(DataclassFactory[Foo]): + __model__ = Foo + + __randomize_collection_length__ = True + __min_collection_length__ = min_val + __max_collection_length__ = max_val + + for _ in range(10): + foo = FooFactory.build() + + if foo.foo is None: + continue + + assert len(foo.foo) >= min_val, len(foo.foo) + assert len(foo.foo) <= max_val, len(foo.foo) diff --git a/tests/test_dataclass_factory.py b/tests/test_dataclass_factory.py index 536a86a6..276b81b5 100644 --- a/tests/test_dataclass_factory.py +++ b/tests/test_dataclass_factory.py @@ -1,10 +1,9 @@ from dataclasses import dataclass as vanilla_dataclass from dataclasses import field from types import ModuleType -from typing import Callable, Dict, List, Optional, Set, Tuple +from typing import Callable, Dict, List, Optional, Tuple from unittest.mock import ANY -from pydantic.dataclasses import Field # type: ignore from pydantic.dataclasses import dataclass as pydantic_dataclass from polyfactory.factories import DataclassFactory @@ -38,7 +37,6 @@ class PydanticDC: id: int name: str list_field: List[Dict[str, int]] - constrained_field: int = Field(ge=100) field_of_some_value: Optional[int] = field(default_factory=lambda: 0) class MyFactory(DataclassFactory[PydanticDC]): @@ -52,7 +50,6 @@ class MyFactory(DataclassFactory[PydanticDC]): assert result.list_field assert result.list_field[0] assert [isinstance(value, int) for value in result.list_field[0].values()] - assert result.constrained_field >= 100 def test_vanilla_dc_with_embedded_model() -> None: @@ -195,87 +192,3 @@ class FooFactory(DataclassFactory[Foo]): # type:ignore[valid-type] foo = FooFactory.build() assert isinstance(foo, Foo) - - -def test_variable_length_tuple_generation__many_type_args() -> None: - @vanilla_dataclass - class VanillaDC: - ids: Tuple[int, ...] - - number_of_args = 3 - - class MyFactory(DataclassFactory[VanillaDC]): - __model__ = VanillaDC - - __randomize_collection_length__ = True - __min_collection_length__ = number_of_args - __max_collection_length__ = number_of_args - - result = MyFactory.build() - - assert result - assert result.ids - assert len(result.ids) == number_of_args - - -def test_variable_length_dict_generation__many_type_args() -> None: - @vanilla_dataclass - class VanillaDC: - ids: Dict[str, int] - - number_of_args = 3 - - class MyFactory(DataclassFactory[VanillaDC]): - __model__ = VanillaDC - - __randomize_collection_length__ = True - __min_collection_length__ = number_of_args - __max_collection_length__ = number_of_args - - result = MyFactory.build() - - assert result - assert result.ids - assert len(result.ids) == number_of_args - - -def test_variable_length_list_generation__many_type_args() -> None: - @vanilla_dataclass - class VanillaDC: - ids: List[int] - - number_of_args = 3 - - class MyFactory(DataclassFactory[VanillaDC]): - __model__ = VanillaDC - - __randomize_collection_length__ = True - __min_collection_length__ = number_of_args - __max_collection_length__ = number_of_args - - result = MyFactory.build() - - assert result - assert result.ids - assert len(result.ids) == number_of_args - - -def test_variable_length_set_generation__many_type_args() -> None: - @vanilla_dataclass - class VanillaDC: - ids: Set[int] - - number_of_args = 3 - - class MyFactory(DataclassFactory[VanillaDC]): - __model__ = VanillaDC - - __randomize_collection_length__ = True - __min_collection_length__ = number_of_args - __max_collection_length__ = number_of_args - - result = MyFactory.build() - - assert result - assert result.ids - assert len(result.ids) == number_of_args diff --git a/tests/test_msgspec_factory.py b/tests/test_msgspec_factory.py index 589756de..b4f13f9c 100644 --- a/tests/test_msgspec_factory.py +++ b/tests/test_msgspec_factory.py @@ -180,58 +180,6 @@ class FooFactory(MsgspecFactory[Foo]): _ = FooFactory.build() -@pytest.mark.parametrize("type_", (Set, FrozenSet, List)) -def test_variable_length(type_: Any) -> None: - class Foo(Struct): - items: type_[int] - - class FooFactory(MsgspecFactory[Foo]): - __model__ = Foo - - __randomize_collection_length__ = True - number_of_args = 3 - - __min_collection_length__ = number_of_args - __max_collection_length__ = number_of_args - - foo = FooFactory.build() - assert len(foo.items) == 3 - - -def test_variable_length__dict() -> None: - class Foo(Struct): - items: Dict[int, float] - - class FooFactory(MsgspecFactory[Foo]): - __model__ = Foo - - __randomize_collection_length__ = True - number_of_args = 3 - - __min_collection_length__ = number_of_args - __max_collection_length__ = number_of_args - - foo = FooFactory.build() - assert len(foo.items) == 3 - - -def test_variable_length__tuple() -> None: - class Foo(Struct): - items: Tuple[int, ...] - - class FooFactory(MsgspecFactory[Foo]): - __model__ = Foo - - __randomize_collection_length__ = True - number_of_args = 3 - - __min_collection_length__ = number_of_args - __max_collection_length__ = number_of_args - - foo = FooFactory.build() - assert len(foo.items) == 3 - - def test_inheritence() -> None: class Foo(Struct): int_field: int diff --git a/tests/test_odmantic_factory.py b/tests/test_odmantic_factory.py index 5164ed87..0c443e5e 100644 --- a/tests/test_odmantic_factory.py +++ b/tests/test_odmantic_factory.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, FrozenSet, List, Set, Tuple +from typing import Any, List from uuid import UUID import bson @@ -101,59 +101,3 @@ def name(cls, id: ObjectId) -> str: return f"{cls.__faker__.name()} [{id.generation_time}]" MyFactory.build() - - -@pytest.mark.parametrize("type_", (Set, FrozenSet, List)) -def test_variable_length(type_: Any) -> None: - class MyModel(Model): # type: ignore - items: type_[bson.Int64] - - class MyFactory(OdmanticModelFactory[MyModel]): - __model__ = MyModel - - __randomize_collection_length__ = True - number_of_args = 3 - - __min_collection_length__ = number_of_args - __max_collection_length__ = number_of_args - - result = MyFactory.build() - - assert len(result.items) == 3 - - -@pytest.mark.skipif(True, reason="test is flaky - refer issue #362") -def test_variable_length__dict() -> None: - class MyModel(Model): # type: ignore - items: Dict[bson.Int64, UUID] - - class MyFactory(OdmanticModelFactory[MyModel]): - __model__ = MyModel - - __randomize_collection_length__ = True - number_of_args = 3 - - __min_collection_length__ = number_of_args - __max_collection_length__ = number_of_args - - result = MyFactory.build() - - assert len(result.items) == 3 - - -def test_variable_length__tuple() -> None: - class MyModel(Model): # type: ignore - items: Tuple[bson.Int64, ...] - - class MyFactory(OdmanticModelFactory[MyModel]): - __model__ = MyModel - - __randomize_collection_length__ = True - number_of_args = 3 - - __min_collection_length__ = number_of_args - __max_collection_length__ = number_of_args - - result = MyFactory.build() - - assert len(result.items) == 3 diff --git a/tests/test_passing_build_args_to_child_factories.py b/tests/test_passing_build_args_to_child_factories.py index 0bc97634..50cfeb2f 100644 --- a/tests/test_passing_build_args_to_child_factories.py +++ b/tests/test_passing_build_args_to_child_factories.py @@ -1,6 +1,5 @@ -from typing import Any, Dict, FrozenSet, List, Mapping, Optional, Set, Tuple +from typing import List, Mapping, Optional -import pytest from pydantic import BaseModel from polyfactory.factories.pydantic_factory import ModelFactory @@ -177,67 +176,3 @@ class DFactory(ModelFactory): build_result = DFactory.build(factory_use_construct=False, **{"c": {"b": {"a": {"name": "test"}}}}) assert build_result assert build_result.c.b.a.name == "test" - - -@pytest.mark.parametrize("type_", (Set, FrozenSet, List)) -def test_variable_length__collection(type_: Any) -> None: - class Vanilla(BaseModel): - ids: type_[int] - - number_of_args = 3 - - class MyFactory(ModelFactory[Vanilla]): - __model__ = Vanilla - - __randomize_collection_length__ = True - __min_collection_length__ = number_of_args - __max_collection_length__ = number_of_args - - result = MyFactory.build() - - assert result - assert result.ids - assert len(result.ids) == number_of_args - - -def test_variable_length__tuple() -> None: - class Chocolate(BaseModel): - val: float - - class Vanilla(BaseModel): - ids: Tuple[Chocolate, ...] - - number_of_args = 3 - - class MyFactory(ModelFactory[Vanilla]): - __model__ = Vanilla - - __randomize_collection_length__ = True - __min_collection_length__ = number_of_args - __max_collection_length__ = number_of_args - - result = MyFactory.build() - - assert result - assert result.ids - assert len(result.ids) == number_of_args - - -def test_variable_length__dict() -> None: - class Vanilla(BaseModel): - ids: Dict[str, float] - - number_of_args = 3 - - class MyFactory(ModelFactory[Vanilla]): - __model__ = Vanilla - - __randomize_collection_length__ = True - __min_collection_length__ = number_of_args - __max_collection_length__ = number_of_args - - result = MyFactory.build() - - assert result - assert result.ids - assert len(result.ids) == number_of_args diff --git a/tests/test_typeddict_factory.py b/tests/test_typeddict_factory.py index 1964b064..9d3a3138 100644 --- a/tests/test_typeddict_factory.py +++ b/tests/test_typeddict_factory.py @@ -44,21 +44,3 @@ class MyFactory(ModelFactory[MyModel]): assert result.td["name"] assert result.td["list_field"][0] assert type(result.td["int_field"]) in (type(None), int) - - -def test_variable_length() -> None: - class MyModel(BaseModel): - items: List[int] - - class MyFactory(ModelFactory[MyModel]): - __model__ = MyModel - - __randomize_collection_length__ = True - number_of_args = 3 - - __min_collection_length__ = number_of_args - __max_collection_length__ = number_of_args - - result = MyFactory.build() - - assert len(result.items) == 3 From 8a3ac1ff322943c4d8b42f7016693274ac9b893a Mon Sep 17 00:00:00 2001 From: guacs <126393040+guacs@users.noreply.github.com> Date: Sun, 15 Oct 2023 00:12:41 +0530 Subject: [PATCH 17/37] refactor: refactor the msgspec factory to use the fields API (#409) --- polyfactory/factories/msgspec_factory.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/polyfactory/factories/msgspec_factory.py b/polyfactory/factories/msgspec_factory.py index 4d4cf1d0..d97e2225 100644 --- a/polyfactory/factories/msgspec_factory.py +++ b/polyfactory/factories/msgspec_factory.py @@ -7,7 +7,6 @@ Callable, Generic, TypeVar, - cast, ) from typing_extensions import get_type_hints @@ -23,7 +22,7 @@ try: import msgspec - from msgspec import inspect + from msgspec.structs import fields except ImportError as e: msg = "msgspec is not installed" raise MissingDependencyException(msg) from e @@ -56,12 +55,11 @@ def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: @classmethod def get_model_fields(cls) -> list[FieldMeta]: - type_info = cast(inspect.StructType, inspect.type_info(cls.__model__)) - fields_meta: list[FieldMeta] = [] - for field in type_info.fields: - annotation = get_type_hints(cls.__model__, include_extras=True)[field.name] + type_hints = get_type_hints(cls.__model__, include_extras=True) + for field in fields(cls.__model__): + annotation = type_hints[field.name] if field.default is not msgspec.NODEFAULT: default_value = field.default elif field.default_factory is not msgspec.NODEFAULT: From 5baf00afafd8a4f4ffb6030ec91834fcd6e019c6 Mon Sep 17 00:00:00 2001 From: guacs <126393040+guacs@users.noreply.github.com> Date: Mon, 16 Oct 2023 08:51:54 +0530 Subject: [PATCH 18/37] chore: prepare for releasing v2.10 (#410) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index af82e9a9..711f6364 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ classifiers = [ "Typing :: Typed", ] name = "polyfactory" -version = "2.9.0" +version = "2.10.0" description = "Mock data generation factories" readme = "docs/PYPI_README.md" license = {text = "MIT"} From 7570acd4bd1aa1c5ace7a082899cc9feb8e0951a Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Wed, 27 Sep 2023 01:19:46 +0000 Subject: [PATCH 19/37] feat(type-coverage-gen): Initial implementation of type coverage generation --- .pre-commit-config.yaml | 2 +- polyfactory/factories/base.py | 5 +++++ polyfactory/utils/helpers.py | 1 + 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 99cf36fa..71f6d84a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ default_language_version: - python: "3.11" + python: "3.10" repos: - repo: https://github.com/compilerla/conventional-pre-commit rev: v2.4.0 diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 62d224f5..fa808af6 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -27,6 +27,7 @@ from types import NoneType except ImportError: NoneType = type(None) # type: ignore[misc,assignment] + from typing import ( TYPE_CHECKING, Any, @@ -67,6 +68,10 @@ is_union, ) from polyfactory.value_generators.complex_types import handle_collection_type +from polyfactory.utils.helpers import flatten_annotation, unwrap_annotation, unwrap_args, unwrap_optional +from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage +from polyfactory.utils.predicates import get_type_origin, is_any, is_literal, is_optional, is_safe_subclass, is_union +from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage from polyfactory.value_generators.constrained_collections import ( handle_constrained_collection, handle_constrained_mapping, diff --git a/polyfactory/utils/helpers.py b/polyfactory/utils/helpers.py index db174354..7a024f93 100644 --- a/polyfactory/utils/helpers.py +++ b/polyfactory/utils/helpers.py @@ -2,6 +2,7 @@ import sys from typing import TYPE_CHECKING, Any, Mapping +from types import NoneType try: from types import NoneType From 87df742ebac9f33ecf4775c1a3c77692a3726442 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Wed, 27 Sep 2023 01:54:24 +0000 Subject: [PATCH 20/37] fix: Update NoneType importing for older python versions --- polyfactory/utils/helpers.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/polyfactory/utils/helpers.py b/polyfactory/utils/helpers.py index 7a024f93..72fece12 100644 --- a/polyfactory/utils/helpers.py +++ b/polyfactory/utils/helpers.py @@ -3,6 +3,12 @@ import sys from typing import TYPE_CHECKING, Any, Mapping from types import NoneType +from typing import TYPE_CHECKING, Any + +try: + from types import NoneType +except ImportError: + NoneType = type(None) # type: ignore[misc,assignment] try: from types import NoneType From 811658b120f964becc2eebe7f7aee906414f9a13 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Mon, 16 Oct 2023 03:37:55 +0000 Subject: [PATCH 21/37] fix: Make CoverageContainer generic --- polyfactory/utils/model_coverage.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/polyfactory/utils/model_coverage.py b/polyfactory/utils/model_coverage.py index e73fd1b2..dad55225 100644 --- a/polyfactory/utils/model_coverage.py +++ b/polyfactory/utils/model_coverage.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Mapping, MutableSequence -from typing import AbstractSet, Any, Generic, Set, TypeVar +from typing import AbstractSet, Any, Generic, Set, TypeVar, cast from typing_extensions import ParamSpec @@ -17,22 +17,25 @@ def is_done(self) -> bool: ... -class CoverageContainer(CoverageContainerBase): - def __init__(self, instances: Iterable[Any]) -> None: +T = TypeVar("T") + + +class CoverageContainer(CoverageContainerBase, Generic[T]): + def __init__(self, instances: Iterable[T]) -> None: self._pos = 0 self._instances = list(instances) if not self._instances: msg = "CoverageContainer must have at least one instance" raise ValueError(msg) - def next_value(self) -> Any: + def next_value(self) -> T: value = self._instances[self._pos % len(self._instances)] if isinstance(value, CoverageContainerBase): result = value.next_value() if value.is_done(): # Only move onto the next instance if the sub-container is done self._pos += 1 - return result + return cast(T, result) self._pos += 1 return value @@ -44,7 +47,6 @@ def __repr__(self) -> str: return f"CoverageContainer(instances={self._instances}, is_done={self.is_done()})" -T = TypeVar("T") P = ParamSpec("P") From bac19716f2d69b2e9cca125020e5e2bd599a6045 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Mon, 16 Oct 2023 03:48:28 +0000 Subject: [PATCH 22/37] fix: linting and rebase issues --- polyfactory/factories/base.py | 17 ++++++++--------- polyfactory/utils/helpers.py | 6 ------ 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index fa808af6..8e8ec015 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -27,7 +27,7 @@ from types import NoneType except ImportError: NoneType = type(None) # type: ignore[misc,assignment] - + from typing import ( TYPE_CHECKING, Any, @@ -54,11 +54,14 @@ ) from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use -from polyfactory.utils.helpers import flatten_annotation, unwrap_annotation, unwrap_args, unwrap_optional +from polyfactory.utils.helpers import ( + flatten_annotation, + get_collection_type, + unwrap_annotation, + unwrap_args, + unwrap_optional, +) from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage -from polyfactory.utils.predicates import get_type_origin, is_any, is_literal, is_optional, is_safe_subclass, is_union -from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage -from polyfactory.utils.helpers import get_collection_type, unwrap_annotation, unwrap_args, unwrap_optional from polyfactory.utils.predicates import ( get_type_origin, is_any, @@ -67,10 +70,6 @@ is_safe_subclass, is_union, ) -from polyfactory.value_generators.complex_types import handle_collection_type -from polyfactory.utils.helpers import flatten_annotation, unwrap_annotation, unwrap_args, unwrap_optional -from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage -from polyfactory.utils.predicates import get_type_origin, is_any, is_literal, is_optional, is_safe_subclass, is_union from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage from polyfactory.value_generators.constrained_collections import ( handle_constrained_collection, diff --git a/polyfactory/utils/helpers.py b/polyfactory/utils/helpers.py index 72fece12..7e1dd21a 100644 --- a/polyfactory/utils/helpers.py +++ b/polyfactory/utils/helpers.py @@ -2,18 +2,12 @@ import sys from typing import TYPE_CHECKING, Any, Mapping -from types import NoneType -from typing import TYPE_CHECKING, Any try: from types import NoneType except ImportError: NoneType = type(None) # type: ignore[misc,assignment] -try: - from types import NoneType -except ImportError: - NoneType = type(None) # type: ignore[misc,assignment] from typing_extensions import get_args, get_origin From deb72a18ae986ab16863857bc516dbaa0ba60380 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Mon, 16 Oct 2023 04:14:25 +0000 Subject: [PATCH 23/37] fix: revert pre-commit conf change --- .pre-commit-config.yaml | 2 +- polyfactory/factories/base.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 71f6d84a..99cf36fa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ default_language_version: - python: "3.10" + python: "3.11" repos: - repo: https://github.com/compilerla/conventional-pre-commit rev: v2.4.0 diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 8e8ec015..cd4e760b 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -85,7 +85,11 @@ from polyfactory.value_generators.constrained_strings import handle_constrained_string_or_bytes from polyfactory.value_generators.constrained_url import handle_constrained_url from polyfactory.value_generators.constrained_uuid import handle_constrained_uuid -from polyfactory.value_generators.primitives import create_random_boolean, create_random_bytes, create_random_string +from polyfactory.value_generators.primitives import ( + create_random_boolean, + create_random_bytes, + create_random_string, +) if TYPE_CHECKING: from typing_extensions import TypeGuard From 85e2525c760477995bce57355f5bbfd805fb098d Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Wed, 18 Oct 2023 22:30:37 +0000 Subject: [PATCH 24/37] docs(type-coverage-gen): Add docs for coverage gen --- docs/examples/model_coverage/__init__.py | 0 .../examples/model_coverage/test_example_1.py | 49 +++++++++++++++++++ .../examples/model_coverage/test_example_2.py | 38 ++++++++++++++ docs/usage/model_coverage.rst | 29 +++++++++++ 4 files changed, 116 insertions(+) create mode 100644 docs/examples/model_coverage/__init__.py create mode 100644 docs/examples/model_coverage/test_example_1.py create mode 100644 docs/examples/model_coverage/test_example_2.py create mode 100644 docs/usage/model_coverage.rst diff --git a/docs/examples/model_coverage/__init__.py b/docs/examples/model_coverage/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/docs/examples/model_coverage/test_example_1.py b/docs/examples/model_coverage/test_example_1.py new file mode 100644 index 00000000..3246bdbc --- /dev/null +++ b/docs/examples/model_coverage/test_example_1.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +from polyfactory.factories.dataclass_factory import DataclassFactory + + +@dataclass +class Car: + model: str + + +@dataclass +class Boat: + can_float: bool + + +@dataclass +class Profile: + age: int + favourite_color: Literal["red", "green", "blue"] + vehicle: Car | Boat + + +class ProfileFactory(DataclassFactory[Profile]): + __model__ = Profile + + +profiles = list(ProfileFactory.coverage()) + +# >>> print(profiles) +[ + Profile( + age=9325, + favourite_color="red", + vehicle=Car(model="hrxarraoDbdkBnpxMEiG"), + ), + Profile( + age=6840, + favourite_color="green", + vehicle=Boat(can_float=False), + ), + Profile( + age=4769, + favourite_color="blue", + vehicle=Car(model="hrxarraoDbdkBnpxMEiG"), + ), +] diff --git a/docs/examples/model_coverage/test_example_2.py b/docs/examples/model_coverage/test_example_2.py new file mode 100644 index 00000000..7426a606 --- /dev/null +++ b/docs/examples/model_coverage/test_example_2.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +from polyfactory.factories.dataclass_factory import DataclassFactory + + +@dataclass +class Car: + model: str + + +@dataclass +class Boat: + can_float: bool + + +@dataclass +class Profile: + age: int + favourite_color: Literal["red", "green", "blue"] + vehicle: Car | Boat + + +@dataclass +class SocialGroup: + members: list[Profile] + + +class SocialGroupFactory(DataclassFactory[SocialGroup]): + __model__ = SocialGroup + + +group = list(SocialGroupFactory.coverage()) + +# >>> print(group) +# >>> SocialGroup(members=[Profile(...), Profile(...), Profile(...)]) diff --git a/docs/usage/model_coverage.rst b/docs/usage/model_coverage.rst new file mode 100644 index 00000000..ba55974b --- /dev/null +++ b/docs/usage/model_coverage.rst @@ -0,0 +1,29 @@ +Model coverage generation +========================= + +The `BaseFactory.coverage()` function is an alternative approach to `BaseFactory.batch()`, where the examples that are generated attempt to provide full coverage of all the forms a model can take with the minimum number of instances. For example: + +.. literalinclude:: /examples/model_coverage/test_example_1.py + :caption: Defining a factory and generating examples with coverage + :language: python + +As you can see in the above example, the `Profile` model has 3 options for `favourite_color`, and 2 options for `vehicle`. In the output you can expect to see instances of `Profile` that have each of these options. The largest variance dictates the length of the output, in this case `favourite_color` has the most, at 3 options, so expect to see 3 `Profile` instances. + + +.. note:: + Notice that the same `Car` instance is used in the first and final generated example. When the coverage examples for a field are exhausted before another field, values for that field are re-used. + +Notes on collection types +------------------------- + +When generating coverage for models with fields that are collections, in particular collections that contain sub-models, the contents of the collection will be the all coverage examples for that sub-model. For example: + +.. literalinclude:: /examples/model_coverage/test_example_2.py + :caption: Coverage output for the SocialGroup model + :language: python + +Known Limitations +----------------- + +- Recursive models will cause an error: `RecursionError: maximum recursion depth exceeded`. +- `__min_collection_length__` and `__max_collection_length__` are currently ignored in coverage generation. From 88923de4c32a7d50cf8d3f71a6f61e12d16dd0a6 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Wed, 18 Oct 2023 22:45:04 +0000 Subject: [PATCH 25/37] docs: Fix formatting in coverage docs --- docs/usage/index.rst | 1 + docs/usage/model_coverage.rst | 10 +++++----- polyfactory/factories/base.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/usage/index.rst b/docs/usage/index.rst index 7569619f..32a49a27 100644 --- a/docs/usage/index.rst +++ b/docs/usage/index.rst @@ -12,3 +12,4 @@ Usage Guide decorators fixtures handling_custom_types + model_coverage diff --git a/docs/usage/model_coverage.rst b/docs/usage/model_coverage.rst index ba55974b..ecb16bf4 100644 --- a/docs/usage/model_coverage.rst +++ b/docs/usage/model_coverage.rst @@ -1,17 +1,17 @@ Model coverage generation ========================= -The `BaseFactory.coverage()` function is an alternative approach to `BaseFactory.batch()`, where the examples that are generated attempt to provide full coverage of all the forms a model can take with the minimum number of instances. For example: +The ``BaseFactory.coverage()`` function is an alternative approach to ``BaseFactory.batch()``, where the examples that are generated attempt to provide full coverage of all the forms a model can take with the minimum number of instances. For example: .. literalinclude:: /examples/model_coverage/test_example_1.py :caption: Defining a factory and generating examples with coverage :language: python -As you can see in the above example, the `Profile` model has 3 options for `favourite_color`, and 2 options for `vehicle`. In the output you can expect to see instances of `Profile` that have each of these options. The largest variance dictates the length of the output, in this case `favourite_color` has the most, at 3 options, so expect to see 3 `Profile` instances. +As you can see in the above example, the ````Profile```` model has 3 options for ``favourite_color``, and 2 options for ``vehicle``. In the output you can expect to see instances of ``Profile`` that have each of these options. The largest variance dictates the length of the output, in this case ``favourite_color`` has the most, at 3 options, so expect to see 3 ``Profile`` instances. .. note:: - Notice that the same `Car` instance is used in the first and final generated example. When the coverage examples for a field are exhausted before another field, values for that field are re-used. + Notice that the same ``Car`` instance is used in the first and final generated example. When the coverage examples for a field are exhausted before another field, values for that field are re-used. Notes on collection types ------------------------- @@ -25,5 +25,5 @@ When generating coverage for models with fields that are collections, in particu Known Limitations ----------------- -- Recursive models will cause an error: `RecursionError: maximum recursion depth exceeded`. -- `__min_collection_length__` and `__max_collection_length__` are currently ignored in coverage generation. +- Recursive models will cause an error: ``RecursionError: maximum recursion depth exceeded``. +- ``__min_collection_length__`` and ``__max_collection_length__`` are currently ignored in coverage generation. diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index cd4e760b..f3c39b0b 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -667,7 +667,7 @@ def get_field_value_coverage( # noqa: C901 :param field_meta: FieldMeta instance. :param field_build_parameters: Any build parameters passed to the factory as kwarg values. - :returns: An arbitrary value. + :returns: An iterable of values. """ if cls.is_ignored_type(field_meta.annotation): From 2df1a26ac64a101d6656518c423d84d5875fa800 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Thu, 2 Nov 2023 01:21:49 +0000 Subject: [PATCH 26/37] docs: Move profile coverage exmaple into test func --- .../examples/model_coverage/test_example_1.py | 29 ++++++------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/docs/examples/model_coverage/test_example_1.py b/docs/examples/model_coverage/test_example_1.py index 3246bdbc..ced96f4a 100644 --- a/docs/examples/model_coverage/test_example_1.py +++ b/docs/examples/model_coverage/test_example_1.py @@ -27,23 +27,12 @@ class ProfileFactory(DataclassFactory[Profile]): __model__ = Profile -profiles = list(ProfileFactory.coverage()) - -# >>> print(profiles) -[ - Profile( - age=9325, - favourite_color="red", - vehicle=Car(model="hrxarraoDbdkBnpxMEiG"), - ), - Profile( - age=6840, - favourite_color="green", - vehicle=Boat(can_float=False), - ), - Profile( - age=4769, - favourite_color="blue", - vehicle=Car(model="hrxarraoDbdkBnpxMEiG"), - ), -] +def test_profile_coverage() -> None: + profiles = list(ProfileFactory.coverage()) + + assert profiles[0].favourite_color == "red" + assert isinstance(profiles[0].vehicle, Car) + assert profiles[1].favourite_color == "green" + assert isinstance(profiles[1].vehicle, Boat) + assert profiles[2].favourite_color == "blue" + assert isinstance(profiles[2].vehicle, Car) From f989775b3a56415fce4d3dab6d92c0315ff757a2 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Thu, 2 Nov 2023 01:28:30 +0000 Subject: [PATCH 27/37] docs: Update social group example to use test func --- docs/examples/model_coverage/test_example_2.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/docs/examples/model_coverage/test_example_2.py b/docs/examples/model_coverage/test_example_2.py index 7426a606..b5f4956d 100644 --- a/docs/examples/model_coverage/test_example_2.py +++ b/docs/examples/model_coverage/test_example_2.py @@ -32,7 +32,20 @@ class SocialGroupFactory(DataclassFactory[SocialGroup]): __model__ = SocialGroup -group = list(SocialGroupFactory.coverage()) +def test_social_group_coverage() -> None: + groups = list(SocialGroupFactory.coverage()) + assert len(groups) == 1 + + members = groups[0].members + assert len(members) == 3 + + assert members[0].favourite_color == "red" + assert isinstance(members[0].vehicle, Car) + assert members[1].favourite_color == "green" + assert isinstance(members[1].vehicle, Boat) + assert members[2].favourite_color == "blue" + assert isinstance(members[2].vehicle, Car) + # >>> print(group) # >>> SocialGroup(members=[Profile(...), Profile(...), Profile(...)]) From b337bc08d871885ee53492ed2aaeccd7a329d147 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Fri, 3 Nov 2023 01:04:53 +0000 Subject: [PATCH 28/37] fix: Address review comments - Add missing docstrings - Move error handling around CoverageContainerCallable to inside - Formatting issue in documentation --- .../examples/model_coverage/test_example_2.py | 4 --- docs/usage/model_coverage.rst | 2 +- polyfactory/factories/base.py | 6 ++--- polyfactory/utils/helpers.py | 4 ++- polyfactory/utils/model_coverage.py | 25 ++++++++++++++++++- 5 files changed, 31 insertions(+), 10 deletions(-) diff --git a/docs/examples/model_coverage/test_example_2.py b/docs/examples/model_coverage/test_example_2.py index b5f4956d..23acdc9b 100644 --- a/docs/examples/model_coverage/test_example_2.py +++ b/docs/examples/model_coverage/test_example_2.py @@ -45,7 +45,3 @@ def test_social_group_coverage() -> None: assert isinstance(members[1].vehicle, Boat) assert members[2].favourite_color == "blue" assert isinstance(members[2].vehicle, Car) - - -# >>> print(group) -# >>> SocialGroup(members=[Profile(...), Profile(...), Profile(...)]) diff --git a/docs/usage/model_coverage.rst b/docs/usage/model_coverage.rst index ecb16bf4..753dedde 100644 --- a/docs/usage/model_coverage.rst +++ b/docs/usage/model_coverage.rst @@ -7,7 +7,7 @@ The ``BaseFactory.coverage()`` function is an alternative approach to ``BaseFact :caption: Defining a factory and generating examples with coverage :language: python -As you can see in the above example, the ````Profile```` model has 3 options for ``favourite_color``, and 2 options for ``vehicle``. In the output you can expect to see instances of ``Profile`` that have each of these options. The largest variance dictates the length of the output, in this case ``favourite_color`` has the most, at 3 options, so expect to see 3 ``Profile`` instances. +As you can see in the above example, the ``Profile`` model has 3 options for ``favourite_color``, and 2 options for ``vehicle``. In the output you can expect to see instances of ``Profile`` that have each of these options. The largest variance dictates the length of the output, in this case ``favourite_color`` has the most, at 3 options, so expect to see 3 ``Profile`` instances. .. note:: diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 56b0bbc2..9acab38c 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -1,5 +1,6 @@ from __future__ import annotations +import typing from abc import ABC, abstractmethod from collections import Counter, abc, deque from contextlib import suppress @@ -668,7 +669,7 @@ def get_field_value_coverage( # noqa: C901 cls, field_meta: FieldMeta, field_build_parameters: Any | None = None, - ) -> abc.Iterable[Any]: + ) -> typing.Iterable[Any]: """Return a field value on the subclass if existing, otherwise returns a mock value. :param field_meta: FieldMeta instance. @@ -716,8 +717,7 @@ def get_field_value_coverage( # noqa: C901 elif callable(unwrapped_annotation): # if value is a callable we can try to naively call it. # this will work for callables that do not require any parameters passed - with suppress(Exception): - yield CoverageContainerCallable(unwrapped_annotation) + yield CoverageContainerCallable(unwrapped_annotation) else: msg = f"Unsupported type: {unwrapped_annotation!r}\n\nEither extend the providers map or add a factory function for this type." raise ParameterException( diff --git a/polyfactory/utils/helpers.py b/polyfactory/utils/helpers.py index db174354..da34a8b1 100644 --- a/polyfactory/utils/helpers.py +++ b/polyfactory/utils/helpers.py @@ -84,7 +84,9 @@ def unwrap_annotation(annotation: Any, random: Random) -> Any: def flatten_annotation(annotation: Any) -> list[Any]: """Flattens an annotation. + :param annotation: A type annotation. + :returns: The flattened annotations. """ flat = [] @@ -92,7 +94,7 @@ def flatten_annotation(annotation: Any) -> list[Any]: flat.extend(flatten_annotation(unwrap_new_type(annotation))) elif is_optional(annotation): flat.append(NoneType) - flat.extend(flatten_annotation(next(arg for arg in get_args(annotation) if arg not in (NoneType, None)))) + flat.extend(flatten_annotation(arg) for arg in get_args(annotation) if arg not in (NoneType, None)) elif is_annotated(annotation): flat.extend(flatten_annotation(get_args(annotation)[0])) elif is_union(annotation): diff --git a/polyfactory/utils/model_coverage.py b/polyfactory/utils/model_coverage.py index dad55225..1663e4d5 100644 --- a/polyfactory/utils/model_coverage.py +++ b/polyfactory/utils/model_coverage.py @@ -6,14 +6,24 @@ from typing_extensions import ParamSpec +from polyfactory.exceptions import ParameterException + class CoverageContainerBase(ABC): + """Base class for coverage container implementations. + + A coverage container is a wrapper providing values for a particular field. Coverage containers return field values and + track a "done" state to indicate that all coverage examples have been generated. + """ + @abstractmethod def next_value(self) -> Any: + """Provide the next value""" ... @abstractmethod def is_done(self) -> bool: + """Indicate if this container has provided every coverage example it has""" ... @@ -21,6 +31,15 @@ def is_done(self) -> bool: class CoverageContainer(CoverageContainerBase, Generic[T]): + """A coverage container that wraps a collection of values. + + When that calling `next_value()` a greater number of times than the length of the given collection will cause duplicate + examples to be returned (wraps around). + + If there are any coverage containers within the given collection, the values from those containers are essentially merged + into the parent container. + """ + def __init__(self, instances: Iterable[T]) -> None: self._pos = 0 self._instances = list(instances) @@ -57,7 +76,11 @@ def __init__(self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> N self._kwargs = kwargs def next_value(self) -> T: - return self._func(*self._args, **self._kwargs) + try: + return self._func(*self._args, **self._kwargs) + except Exception as e: # noqa: BLE001 + msg = f"Unsupported type: {self._func!r}\n\nEither extend the providers map or add a factory function for this type." + raise ParameterException(msg) from e def is_done(self) -> bool: return True From be5e71228c5988baf493dd241a97b859819e7d26 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Fri, 3 Nov 2023 01:23:02 +0000 Subject: [PATCH 29/37] test: Remove 3.10 requirement for coverage tests --- tests/test_type_coverage_generation.py | 85 +++++++++++++------------- 1 file changed, 42 insertions(+), 43 deletions(-) diff --git a/tests/test_type_coverage_generation.py b/tests/test_type_coverage_generation.py index cc3a6380..036fc038 100644 --- a/tests/test_type_coverage_generation.py +++ b/tests/test_type_coverage_generation.py @@ -1,27 +1,27 @@ +# ruff: noqa: UP007 from __future__ import annotations -import sys -from dataclasses import dataclass +from dataclasses import dataclass, make_dataclass from datetime import date -from typing import Literal, Sequence +from typing import FrozenSet, List, Literal, Set, Union from uuid import UUID import pytest from typing_extensions import TypedDict from polyfactory.decorators import post_generated +from polyfactory.exceptions import ParameterException from polyfactory.factories.dataclass_factory import DataclassFactory from polyfactory.factories.typed_dict_factory import TypedDictFactory -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_count() -> None: @dataclass class Profile: name: str - high_score: int | float + high_score: Union[int, float] dob: date - data: str | date | int | float + data: Union[str, date, int, float] class ProfileFactory(DataclassFactory[Profile]): __model__ = Profile @@ -34,11 +34,10 @@ class ProfileFactory(DataclassFactory[Profile]): assert isinstance(result, Profile) -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_tuple() -> None: @dataclass class Tuple: - tuple_: tuple[int | str, tuple[int | float, int]] + tuple_: tuple[Union[int, str], tuple[Union[int, float], int]] class TupleFactory(DataclassFactory[Tuple]): __model__ = Tuple @@ -54,16 +53,14 @@ class TupleFactory(DataclassFactory[Tuple]): assert isinstance(a1, str) and isinstance(b1, float) and isinstance(c1, int) -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") -def test_coverage_collection() -> None: - @dataclass - class Collective: - set_: set[int | str] - list_: list[int | str] - frozenset_: frozenset[int | str] - sequence_: Sequence[int | str] +@pytest.mark.parametrize( + "collection_annotation", + (Set[Union[int, str]], List[Union[int, str]], FrozenSet[Union[int, str]]), +) +def test_coverage_collection(collection_annotation: type) -> None: + Collective = make_dataclass("Collective", [("collection", collection_annotation)]) - class CollectiveFactory(DataclassFactory[Collective]): + class CollectiveFactory(DataclassFactory[Collective]): # type: ignore __model__ = Collective results = list(CollectiveFactory.coverage()) @@ -72,22 +69,14 @@ class CollectiveFactory(DataclassFactory[Collective]): result = results[0] - assert len(result.set_) == 2 - assert len(result.list_) == 2 - assert len(result.frozenset_) == 2 - assert len(result.sequence_) == 2 + collection = result.collection # type: ignore - v0, v1 = result.set_ - assert {type(v0), type(v1)} == {int, str} - v0, v1 = result.list_ - assert {type(v0), type(v1)} == {int, str} - v0, v1 = result.frozenset_ - assert {type(v0), type(v1)} == {int, str} - v0, v1 = result.sequence_ + assert len(collection) == 2 + + v0, v1 = collection assert {type(v0), type(v1)} == {int, str} -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_literal() -> None: @dataclass class Literally: @@ -106,13 +95,12 @@ class LiterallyFactory(DataclassFactory[Literally]): assert results[3].literal == 2 -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_dict() -> None: @dataclass class Thesaurus: dict_simple: dict[str, int] - dict_more_key_types: dict[str | int | float, int | str] - dict_more_value_types: dict[str, int | str] + dict_more_key_types: dict[Union[str, int, float], Union[int, str]] + dict_more_value_types: dict[str, Union[int, str]] class ThesaurusFactory(DataclassFactory[Thesaurus]): __model__ = Thesaurus @@ -123,11 +111,10 @@ class ThesaurusFactory(DataclassFactory[Thesaurus]): @pytest.mark.skip(reason="Does not support recursive types yet.") -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_recursive() -> None: @dataclass class Recursive: - r: Recursive | None + r: Union[Recursive, None] class RecursiveFactory(DataclassFactory[Recursive]): __model__ = Recursive @@ -136,13 +123,12 @@ class RecursiveFactory(DataclassFactory[Recursive]): assert len(results) == 2 -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_typed_dict() -> None: class TypedThesaurus(TypedDict): number: int string: str - union: int | str - collection: list[int | str] + union: Union[int, str] + collection: list[Union[int, str]] class TypedThesaurusFactory(TypedDictFactory[TypedThesaurus]): __model__ = TypedThesaurus @@ -156,13 +142,12 @@ class TypedThesaurusFactory(TypedDictFactory[TypedThesaurus]): assert result.keys() == example.keys() -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_typed_dict_field() -> None: class TypedThesaurus(TypedDict): number: int string: str - union: int | str - collection: list[int | str] + union: Union[int, str] + collection: list[Union[int, str]] class TypedThesaurusFactory(TypedDictFactory[TypedThesaurus]): __model__ = TypedThesaurus @@ -177,12 +162,11 @@ class TypedThesaurusFactory(TypedDictFactory[TypedThesaurus]): assert result.keys() == example.keys() -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_values_unique() -> None: @dataclass class Unique: uuid: UUID - data: int | str + data: Union[int, str] class UniqueFactory(DataclassFactory[Unique]): __model__ = Unique @@ -193,7 +177,6 @@ class UniqueFactory(DataclassFactory[Unique]): assert results[0].uuid != results[1].uuid -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_coverage_post_generated() -> None: @dataclass class Model: @@ -212,3 +195,19 @@ def i(cls, j: int) -> int: assert len(results) == 1 assert results[0].i == results[0].j + 10 + + +def test_coverage_parameter_exception() -> None: + @dataclass + class Model: + class CustomInt: + def __init__(self, value: int) -> None: + self.value = value + + i: CustomInt + + class Factory(DataclassFactory[Model]): + __model__ = Model + + with pytest.raises(ParameterException): + list(Factory.coverage()) From eb075e0348d0bf290f594e181febef2ee46b2524 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Fri, 3 Nov 2023 01:30:09 +0000 Subject: [PATCH 30/37] test: Move CustomInt definition outside of test --- tests/test_type_coverage_generation.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_type_coverage_generation.py b/tests/test_type_coverage_generation.py index 036fc038..c9fb64f9 100644 --- a/tests/test_type_coverage_generation.py +++ b/tests/test_type_coverage_generation.py @@ -197,13 +197,14 @@ def i(cls, j: int) -> int: assert results[0].i == results[0].j + 10 +class CustomInt: + def __init__(self, value: int) -> None: + self.value = value + + def test_coverage_parameter_exception() -> None: @dataclass class Model: - class CustomInt: - def __init__(self, value: int) -> None: - self.value = value - i: CustomInt class Factory(DataclassFactory[Model]): From 52995357e061dbb744af3c882d3fd193d0de73dd Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Fri, 3 Nov 2023 02:19:57 +0000 Subject: [PATCH 31/37] test: disable ruff UP006 in test file --- tests/test_type_coverage_generation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_type_coverage_generation.py b/tests/test_type_coverage_generation.py index c9fb64f9..bbba3d0e 100644 --- a/tests/test_type_coverage_generation.py +++ b/tests/test_type_coverage_generation.py @@ -1,4 +1,4 @@ -# ruff: noqa: UP007 +# ruff: noqa: UP007, UP006 from __future__ import annotations from dataclasses import dataclass, make_dataclass @@ -128,7 +128,7 @@ class TypedThesaurus(TypedDict): number: int string: str union: Union[int, str] - collection: list[Union[int, str]] + collection: List[Union[int, str]] class TypedThesaurusFactory(TypedDictFactory[TypedThesaurus]): __model__ = TypedThesaurus @@ -147,7 +147,7 @@ class TypedThesaurus(TypedDict): number: int string: str union: Union[int, str] - collection: list[Union[int, str]] + collection: List[Union[int, str]] class TypedThesaurusFactory(TypedDictFactory[TypedThesaurus]): __model__ = TypedThesaurus From 93f865849aafdf46df940ab0e9a020a700492d6d Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Fri, 3 Nov 2023 02:26:36 +0000 Subject: [PATCH 32/37] test: fix social group test in docs example --- docs/examples/model_coverage/test_example_2.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/examples/model_coverage/test_example_2.py b/docs/examples/model_coverage/test_example_2.py index 23acdc9b..1e38efdf 100644 --- a/docs/examples/model_coverage/test_example_2.py +++ b/docs/examples/model_coverage/test_example_2.py @@ -34,14 +34,14 @@ class SocialGroupFactory(DataclassFactory[SocialGroup]): def test_social_group_coverage() -> None: groups = list(SocialGroupFactory.coverage()) - assert len(groups) == 1 + assert len(groups) == 3 - members = groups[0].members - assert len(members) == 3 + for group in groups: + assert len(group.members) == 1 - assert members[0].favourite_color == "red" - assert isinstance(members[0].vehicle, Car) - assert members[1].favourite_color == "green" - assert isinstance(members[1].vehicle, Boat) - assert members[2].favourite_color == "blue" - assert isinstance(members[2].vehicle, Car) + assert groups[0].members[0].favourite_color == "red" + assert isinstance(groups[0].members[0].vehicle, Car) + assert groups[1].members[1].favourite_color == "green" + assert isinstance(groups[1].members[0].vehicle, Boat) + assert groups[2].members[2].favourite_color == "blue" + assert isinstance(groups[2].members[0].vehicle, Car) From d70c52646b0cab14ea61934737da9577176a79a7 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Fri, 3 Nov 2023 02:29:44 +0000 Subject: [PATCH 33/37] test: fix social group test in doc example --- docs/examples/model_coverage/test_example_2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/examples/model_coverage/test_example_2.py b/docs/examples/model_coverage/test_example_2.py index 1e38efdf..bf67f959 100644 --- a/docs/examples/model_coverage/test_example_2.py +++ b/docs/examples/model_coverage/test_example_2.py @@ -41,7 +41,7 @@ def test_social_group_coverage() -> None: assert groups[0].members[0].favourite_color == "red" assert isinstance(groups[0].members[0].vehicle, Car) - assert groups[1].members[1].favourite_color == "green" + assert groups[1].members[0].favourite_color == "green" assert isinstance(groups[1].members[0].vehicle, Boat) - assert groups[2].members[2].favourite_color == "blue" + assert groups[2].members[0].favourite_color == "blue" assert isinstance(groups[2].members[0].vehicle, Car) From d6e2ce8371c20985474675c8c6d7cda0afece193 Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Fri, 3 Nov 2023 02:35:07 +0000 Subject: [PATCH 34/37] test: Change hint dict to Dict in coverage test --- tests/test_type_coverage_generation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_type_coverage_generation.py b/tests/test_type_coverage_generation.py index bbba3d0e..27a94ce6 100644 --- a/tests/test_type_coverage_generation.py +++ b/tests/test_type_coverage_generation.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, make_dataclass from datetime import date -from typing import FrozenSet, List, Literal, Set, Union +from typing import Dict, FrozenSet, List, Literal, Set, Union from uuid import UUID import pytest @@ -98,9 +98,9 @@ class LiterallyFactory(DataclassFactory[Literally]): def test_coverage_dict() -> None: @dataclass class Thesaurus: - dict_simple: dict[str, int] - dict_more_key_types: dict[Union[str, int, float], Union[int, str]] - dict_more_value_types: dict[str, Union[int, str]] + dict_simple: Dict[str, int] + dict_more_key_types: Dict[Union[str, int, float], Union[int, str]] + dict_more_value_types: Dict[str, Union[int, str]] class ThesaurusFactory(DataclassFactory[Thesaurus]): __model__ = Thesaurus From d0cf706dd561b6be49900047f28cc0c54679ebca Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Fri, 3 Nov 2023 02:44:02 +0000 Subject: [PATCH 35/37] test: Fix tuple annotation in coverage tests --- tests/test_type_coverage_generation.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_type_coverage_generation.py b/tests/test_type_coverage_generation.py index 27a94ce6..047c9954 100644 --- a/tests/test_type_coverage_generation.py +++ b/tests/test_type_coverage_generation.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, make_dataclass from datetime import date -from typing import Dict, FrozenSet, List, Literal, Set, Union +from typing import Dict, FrozenSet, List, Literal, Set, Tuple, Union from uuid import UUID import pytest @@ -36,11 +36,11 @@ class ProfileFactory(DataclassFactory[Profile]): def test_coverage_tuple() -> None: @dataclass - class Tuple: - tuple_: tuple[Union[int, str], tuple[Union[int, float], int]] + class Pair: + tuple_: Tuple[Union[int, str], Tuple[Union[int, float], int]] - class TupleFactory(DataclassFactory[Tuple]): - __model__ = Tuple + class TupleFactory(DataclassFactory[Pair]): + __model__ = Pair results = list(TupleFactory.coverage()) From ba355e507b8bc6bd5760bd08b3e724401019da8e Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Mon, 6 Nov 2023 03:17:32 +0000 Subject: [PATCH 36/37] chore: fix formatting in docstring --- polyfactory/utils/model_coverage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/polyfactory/utils/model_coverage.py b/polyfactory/utils/model_coverage.py index 1663e4d5..02b31812 100644 --- a/polyfactory/utils/model_coverage.py +++ b/polyfactory/utils/model_coverage.py @@ -33,7 +33,7 @@ def is_done(self) -> bool: class CoverageContainer(CoverageContainerBase, Generic[T]): """A coverage container that wraps a collection of values. - When that calling `next_value()` a greater number of times than the length of the given collection will cause duplicate + When calling ``next_value()`` a greater number of times than the length of the given collection will cause duplicate examples to be returned (wraps around). If there are any coverage containers within the given collection, the values from those containers are essentially merged From 9c84551caa734923ff2b3cc4789c44de4817f79a Mon Sep 17 00:00:00 2001 From: sam-or <67568065+sam-or@users.noreply.github.com> Date: Sun, 12 Nov 2023 06:05:57 +0000 Subject: [PATCH 37/37] chore: Add docstring to CoverageContainerCallable --- polyfactory/utils/model_coverage.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/polyfactory/utils/model_coverage.py b/polyfactory/utils/model_coverage.py index 02b31812..6fc39714 100644 --- a/polyfactory/utils/model_coverage.py +++ b/polyfactory/utils/model_coverage.py @@ -70,6 +70,11 @@ def __repr__(self) -> str: class CoverageContainerCallable(CoverageContainerBase, Generic[T]): + """A coverage container that wraps a callable. + + When calling ``next_value()`` the wrapped callable is called to provide a value. + """ + def __init__(self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None: self._func = func self._args = args