diff --git a/pydantic_ai_slim/pydantic_ai/format_prompt.py b/pydantic_ai_slim/pydantic_ai/format_prompt.py index 34f06a40a..da09ee92f 100644 --- a/pydantic_ai_slim/pydantic_ai/format_prompt.py +++ b/pydantic_ai_slim/pydantic_ai/format_prompt.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations from collections.abc import Iterable, Iterator, Mapping -from dataclasses import asdict, dataclass, is_dataclass +from dataclasses import asdict, dataclass, field, is_dataclass from datetime import date from typing import Any from xml.etree import ElementTree @@ -10,6 +10,8 @@ __all__ = ('format_as_xml',) +from pydantic.fields import ComputedFieldInfo, FieldInfo + def format_as_xml( obj: Any, @@ -17,6 +19,8 @@ def format_as_xml( item_tag: str = 'item', none_str: str = 'null', indent: str | None = ' ', + include_field_info: bool = False, + repeat_field_info: bool = False, ) -> str: """Format a Python object as XML. @@ -33,6 +37,10 @@ def format_as_xml( for dataclasses and Pydantic models. none_str: String to use for `None` values. indent: Indentation string to use for pretty printing. + include_field_info: Whether to include attributes like Pydantic Field attributes (title, description, alias) + as XML attributes. + repeat_field_info: Whether to include XML attributes extracted from a field info for each occurrence of an XML + element relative to the same field. Returns: XML representation of the object. @@ -51,7 +59,13 @@ def format_as_xml( ''' ``` """ - el = _ToXml(item_tag=item_tag, none_str=none_str).to_xml(obj, root_tag) + el = _ToXml( + data=obj, + item_tag=item_tag, + none_str=none_str, + include_field_info=include_field_info, + repeat_field_info=repeat_field_info, + ).to_xml(root_tag) if root_tag is None and el.text is None: join = '' if indent is None else '\n' return join.join(_rootless_xml_elements(el, indent)) @@ -63,11 +77,24 @@ def format_as_xml( @dataclass class _ToXml: + data: Any item_tag: str none_str: str - - def to_xml(self, value: Any, tag: str | None) -> ElementTree.Element: - element = ElementTree.Element(self.item_tag if tag is None else tag) + include_field_info: bool + repeat_field_info: bool + # a map of Pydantic Field paths to their metadata: a field unique string representation and its class + _fields: dict[str, tuple[str, FieldInfo | ComputedFieldInfo]] | None = None + # keep track of fields we have extracted attributes from + _parsed_fields: set[str] = field(default_factory=set) + # keep track of class names for dataclasses and Pydantic models, that occur in lists + _element_names: dict[str, str] | None = None + _FIELD_ATTRIBUTES = ('title', 'description', 'alias') + + def to_xml(self, tag: str | None) -> ElementTree.Element: + return self._to_xml(self.data, tag) + + def _to_xml(self, value: Any, tag: str | None, path: str = '') -> ElementTree.Element: + element = self._create_element(self.item_tag if tag is None else tag, path) if value is None: element.text = self.none_str elif isinstance(value, str): @@ -79,31 +106,105 @@ def to_xml(self, value: Any, tag: str | None) -> ElementTree.Element: elif isinstance(value, date): element.text = value.isoformat() elif isinstance(value, Mapping): - self._mapping_to_xml(element, value) # pyright: ignore[reportUnknownArgumentType] + if tag is None and self._element_names and path in self._element_names: + element = self._create_element(self._element_names[path], path) + self._mapping_to_xml(element, value, path) # pyright: ignore[reportUnknownArgumentType] elif is_dataclass(value) and not isinstance(value, type): + self._init_element_names() if tag is None: - element = ElementTree.Element(value.__class__.__name__) - dc_dict = asdict(value) - self._mapping_to_xml(element, dc_dict) + element = self._create_element(value.__class__.__name__, path) + self._mapping_to_xml(element, asdict(value), path) elif isinstance(value, BaseModel): + # before serializing the model and losing all the metadata of other data structures contained in it, + # we extract all the fields info and class names + self._init_fields_info() + self._init_element_names() if tag is None: - element = ElementTree.Element(value.__class__.__name__) - self._mapping_to_xml(element, value.model_dump(mode='python')) + element = self._create_element(value.__class__.__name__, path) + self._mapping_to_xml(element, value.model_dump(mode='python'), path) elif isinstance(value, Iterable): - for item in value: # pyright: ignore[reportUnknownVariableType] - item_el = self.to_xml(item, None) - element.append(item_el) + for n, item in enumerate(value): # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType] + element.append(self._to_xml(item, None, f'{path}.[{n}]' if path else f'[{n}]')) else: raise TypeError(f'Unsupported type for XML formatting: {type(value)}') return element - def _mapping_to_xml(self, element: ElementTree.Element, mapping: Mapping[Any, Any]) -> None: + def _create_element(self, tag: str, path: str) -> ElementTree.Element: + element = ElementTree.Element(tag) + if self._fields and path in self._fields: + field_repr, field_info = self._fields[path] + if self.repeat_field_info or field_repr not in self._parsed_fields: + field_attributes = self._extract_attributes(field_info) + for k, v in field_attributes.items(): + element.set(k, v) + self._parsed_fields.add(field_repr) + return element + + def _init_fields_info(self): + if self.include_field_info and self._fields is None: + self._fields = {} + self._parse_data_structures(self.data, fields_map=self._fields) + + def _init_element_names(self): + if self._element_names is None: + self._element_names = {} + self._parse_data_structures(self.data, element_names=self._element_names) + + def _mapping_to_xml( + self, + element: ElementTree.Element, + mapping: Mapping[Any, Any], + path: str = '', + ) -> None: for key, value in mapping.items(): if isinstance(key, int): key = str(key) elif not isinstance(key, str): raise TypeError(f'Unsupported key type for XML formatting: {type(key)}, only str and int are allowed') - element.append(self.to_xml(value, key)) + element.append(self._to_xml(value, key, f'{path}.{key}' if path else key)) + + @classmethod + def _parse_data_structures( + cls, + value: Any, + element_names: dict[str, str] | None = None, + fields_map: dict[str, tuple[str, FieldInfo | ComputedFieldInfo]] | None = None, + path: str = '', + ): + """Parse data structures as dataclasses or Pydantic models to extract element names and attributes.""" + if value is None or isinstance(value, (str, int, float, date, bytearray, bytes, bool)): + return + elif isinstance(value, Mapping): + for k, v in value.items(): # pyright: ignore[reportUnknownVariableType] + cls._parse_data_structures(v, element_names, fields_map, f'{path}.{k}' if path else f'{k}') + elif is_dataclass(value) and not isinstance(value, type): + if element_names is not None: + element_names[path] = value.__class__.__name__ + for k, v in asdict(value).items(): + cls._parse_data_structures(v, element_names, fields_map, f'{path}.{k}' if path else f'{k}') + elif isinstance(value, BaseModel): + if element_names is not None: + element_names[path] = value.__class__.__name__ + for model_fields in (value.__class__.model_fields, value.__class__.model_computed_fields): + for field, info in model_fields.items(): + new_path = f'{path}.{field}' if path else field + field_repr = f'{value.__class__.__name__}.{field}' + if (fields_map is not None) and (isinstance(info, ComputedFieldInfo) or not info.exclude): + fields_map[new_path] = (field_repr, info) + cls._parse_data_structures(getattr(value, field), element_names, fields_map, new_path) + elif isinstance(value, Iterable): + for n, item in enumerate(value): # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType] + new_path = f'{path}.[{n}]' if path else f'[{n}]' + cls._parse_data_structures(item, element_names, fields_map, new_path) + + @classmethod + def _extract_attributes(cls, info: FieldInfo | ComputedFieldInfo) -> dict[str, str]: + attributes: dict[str, str] = {} + for attr in cls._FIELD_ATTRIBUTES: + attr_value = getattr(info, attr, None) + if attr_value is not None: + attributes[attr] = str(attr_value) + return attributes def _rootless_xml_elements(root: ElementTree.Element, indent: str | None) -> Iterator[str]: diff --git a/tests/test_format_as_xml.py b/tests/test_format_as_xml.py index 37053a67f..edcb7ca48 100644 --- a/tests/test_format_as_xml.py +++ b/tests/test_format_as_xml.py @@ -1,10 +1,13 @@ +from __future__ import annotations as _annotations + from dataclasses import dataclass from datetime import date, datetime from typing import Any import pytest from inline_snapshot import snapshot -from pydantic import BaseModel +from pydantic import BaseModel, Field, computed_field +from typing_extensions import Self from pydantic_ai import format_as_xml @@ -20,6 +23,19 @@ class ExamplePydanticModel(BaseModel): age: int +class ExamplePydanticFields(BaseModel): + name: str = Field(description="The person's name") + age: int = Field(description='Years', title='Age', default=18) + height: float = Field(description="The person's height", exclude=True) + children: list[Self] | None = Field(alias='child', default=None) + + @computed_field(title='Location') + def location(self) -> str | None: + if self.name == 'John': + return 'Australia' + return None + + @pytest.mark.parametrize( 'input_obj,output', [ @@ -124,7 +140,366 @@ class ExamplePydanticModel(BaseModel): ], ) def test_root_tag(input_obj: Any, output: str): - assert format_as_xml(input_obj, root_tag='examples', item_tag='example') == output + assert format_as_xml(input_obj, root_tag='examples', item_tag='example', include_field_info=False) == output + assert format_as_xml(input_obj, root_tag='examples', item_tag='example', include_field_info=True) == output + + +@pytest.mark.parametrize( + 'input_obj,use_fields,output', + [ + pytest.param( + ExamplePydanticFields( + name='John', + age=42, + height=160.0, + child=[ + ExamplePydanticFields(name='Liam', height=150), + ExamplePydanticFields(name='Alice', height=160), + ], + ), + True, + snapshot("""\ +John +42 + + + Liam + 18 + null + null + + + Alice + 18 + null + null + + +Australia\ +"""), + id='pydantic model with fields', + ), + pytest.param( + [ + ExamplePydanticFields( + name='John', + age=42, + height=160.0, + child=[ + ExamplePydanticFields(name='Liam', height=150), + ExamplePydanticFields(name='Alice', height=160), + ], + ) + ], + True, + snapshot("""\ + + John + 42 + + + Liam + 18 + null + null + + + Alice + 18 + null + null + + + Australia +\ +"""), + id='list[pydantic model with fields]', + ), + pytest.param( + ExamplePydanticFields( + name='John', + age=42, + height=160.0, + child=[ + ExamplePydanticFields(name='Liam', height=150), + ExamplePydanticFields(name='Alice', height=160), + ], + ), + False, + snapshot("""\ +John +42 + + + Liam + 18 + null + null + + + Alice + 18 + null + null + + +Australia\ +"""), + id='pydantic model without fields', + ), + ], +) +def test_fields(input_obj: Any, use_fields: bool, output: str): + assert format_as_xml(input_obj, include_field_info=use_fields) == output + + +def test_repeated_field_attributes(): + class DataItem(BaseModel): + user1: ExamplePydanticFields + user2: ExamplePydanticFields + + data = ExamplePydanticFields( + name='John', + age=42, + height=160.0, + child=[ + ExamplePydanticFields(name='Liam', height=150), + ExamplePydanticFields(name='Alice', height=160), + ], + ) + assert ( + format_as_xml(data, include_field_info=True, repeat_field_info=True) + == """\ +John +42 + + + Liam + 18 + null + null + + + Alice + 18 + null + null + + +Australia\ +""" + ) + + assert ( + format_as_xml(DataItem(user1=data, user2=data.model_copy()), include_field_info=True, repeat_field_info=True) + == """\ + + John + 42 + + + Liam + 18 + null + null + + + Alice + 18 + null + null + + + Australia + + + John + 42 + + + Liam + 18 + null + null + + + Alice + 18 + null + null + + + Australia +\ +""" + ) + + assert ( + format_as_xml(DataItem(user1=data, user2=data.model_copy()), include_field_info=True, repeat_field_info=False) + == """\ + + John + 42 + + + Liam + 18 + null + null + + + Alice + 18 + null + null + + + Australia + + + John + 42 + + + Liam + 18 + null + null + + + Alice + 18 + null + null + + + Australia +\ +""" + ) + + +def test_nested_data(): + @dataclass + class DataItem1: + id: str | None = None + + class ModelItem1(BaseModel): + name: str = Field(description='Name') + value: int + items: list[DataItem1] = Field(description='Items') + + @dataclass + class DataItem2: + model: ModelItem1 + others: tuple[ModelItem1] | None = None + count: int = 10 + + data = { + 'values': [ + DataItem2( + ModelItem1(name='Alice', value=42, items=[DataItem1('xyz')]), + (ModelItem1(name='Liam', value=3, items=[]),), + ), + DataItem2( + ModelItem1( + name='Bob', + value=7, + items=[ + DataItem1('a'), + DataItem1(), + ], + ), + count=42, + ), + ] + } + + assert ( + format_as_xml(data, include_field_info=True) + == """ + + + + Alice + 42 + + + xyz + + + + + + Liam + 3 + + + + 10 + + + + Bob + 7 + + + a + + + null + + + + null + 42 + + +""".strip() + ) + + assert ( + format_as_xml(data, include_field_info=False) + == """ + + + + Alice + 42 + + + xyz + + + + + + Liam + 3 + + + + 10 + + + + Bob + 7 + + + a + + + null + + + + null + 42 + + +""".strip() + ) @pytest.mark.parametrize( @@ -194,6 +569,15 @@ def test_invalid_key(): format_as_xml({(1, 2): 42}) +def test_parse_invalid_value(): + class Invalid(BaseModel): + name: str = Field(default='Alice', title='Name') + bad: Any = object() + + with pytest.raises(TypeError, match='Unsupported type'): + format_as_xml(Invalid(), include_field_info=True) + + def test_set(): assert '1' in format_as_xml({1, 2, 3}, item_tag='example')