diff --git a/pydantic_ai_slim/pydantic_ai/format_prompt.py b/pydantic_ai_slim/pydantic_ai/format_prompt.py
index 34f06a40a..da09ee92f 100644
--- a/pydantic_ai_slim/pydantic_ai/format_prompt.py
+++ b/pydantic_ai_slim/pydantic_ai/format_prompt.py
@@ -1,7 +1,7 @@
from __future__ import annotations as _annotations
from collections.abc import Iterable, Iterator, Mapping
-from dataclasses import asdict, dataclass, is_dataclass
+from dataclasses import asdict, dataclass, field, is_dataclass
from datetime import date
from typing import Any
from xml.etree import ElementTree
@@ -10,6 +10,8 @@
__all__ = ('format_as_xml',)
+from pydantic.fields import ComputedFieldInfo, FieldInfo
+
def format_as_xml(
obj: Any,
@@ -17,6 +19,8 @@ def format_as_xml(
item_tag: str = 'item',
none_str: str = 'null',
indent: str | None = ' ',
+ include_field_info: bool = False,
+ repeat_field_info: bool = False,
) -> str:
"""Format a Python object as XML.
@@ -33,6 +37,10 @@ def format_as_xml(
for dataclasses and Pydantic models.
none_str: String to use for `None` values.
indent: Indentation string to use for pretty printing.
+ include_field_info: Whether to add Pydantic field metadata (`title`, `description`, `alias`) to the
+ corresponding elements as XML attributes.
+ repeat_field_info: Whether to emit those XML attributes on every element generated from the same field,
+ instead of only on the first occurrence.
Returns:
XML representation of the object.
@@ -51,7 +59,13 @@ def format_as_xml(
'''
```
"""
- el = _ToXml(item_tag=item_tag, none_str=none_str).to_xml(obj, root_tag)
+ el = _ToXml(
+ data=obj,
+ item_tag=item_tag,
+ none_str=none_str,
+ include_field_info=include_field_info,
+ repeat_field_info=repeat_field_info,
+ ).to_xml(root_tag)
if root_tag is None and el.text is None:
join = '' if indent is None else '\n'
return join.join(_rootless_xml_elements(el, indent))
@@ -63,11 +77,24 @@ def format_as_xml(
@dataclass
class _ToXml:
+ data: Any
item_tag: str
none_str: str
-
- def to_xml(self, value: Any, tag: str | None) -> ElementTree.Element:
- element = ElementTree.Element(self.item_tag if tag is None else tag)
+ include_field_info: bool
+ repeat_field_info: bool
+ # map each Pydantic field path to its metadata: a unique `ClassName.field` identifier and the field info object
+ _fields: dict[str, tuple[str, FieldInfo | ComputedFieldInfo]] | None = None
+ # keep track of fields we have extracted attributes from
+ _parsed_fields: set[str] = field(default_factory=set)
+ # map paths to the class names of dataclasses and Pydantic models, used to name untagged elements (e.g. list items)
+ _element_names: dict[str, str] | None = None
+ _FIELD_ATTRIBUTES = ('title', 'description', 'alias')
+
+ def to_xml(self, tag: str | None) -> ElementTree.Element:
+ return self._to_xml(self.data, tag)
+
+ def _to_xml(self, value: Any, tag: str | None, path: str = '') -> ElementTree.Element:
+ element = self._create_element(self.item_tag if tag is None else tag, path)
if value is None:
element.text = self.none_str
elif isinstance(value, str):
@@ -79,31 +106,105 @@ def to_xml(self, value: Any, tag: str | None) -> ElementTree.Element:
elif isinstance(value, date):
element.text = value.isoformat()
elif isinstance(value, Mapping):
- self._mapping_to_xml(element, value) # pyright: ignore[reportUnknownArgumentType]
+ if tag is None and self._element_names and path in self._element_names:
+ element = self._create_element(self._element_names[path], path)
+ self._mapping_to_xml(element, value, path) # pyright: ignore[reportUnknownArgumentType]
elif is_dataclass(value) and not isinstance(value, type):
+ self._init_element_names()
if tag is None:
- element = ElementTree.Element(value.__class__.__name__)
- dc_dict = asdict(value)
- self._mapping_to_xml(element, dc_dict)
+ element = self._create_element(value.__class__.__name__, path)
+ self._mapping_to_xml(element, asdict(value), path)
elif isinstance(value, BaseModel):
+ # `model_dump()` discards the metadata of the data structures nested in the model, so before serializing
+ # we extract all the field info and class names
+ self._init_fields_info()
+ self._init_element_names()
if tag is None:
- element = ElementTree.Element(value.__class__.__name__)
- self._mapping_to_xml(element, value.model_dump(mode='python'))
+ element = self._create_element(value.__class__.__name__, path)
+ self._mapping_to_xml(element, value.model_dump(mode='python'), path)
elif isinstance(value, Iterable):
- for item in value: # pyright: ignore[reportUnknownVariableType]
- item_el = self.to_xml(item, None)
- element.append(item_el)
+ for n, item in enumerate(value): # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType]
+ element.append(self._to_xml(item, None, f'{path}.[{n}]' if path else f'[{n}]'))
else:
raise TypeError(f'Unsupported type for XML formatting: {type(value)}')
return element
- def _mapping_to_xml(self, element: ElementTree.Element, mapping: Mapping[Any, Any]) -> None:
+ def _create_element(self, tag: str, path: str) -> ElementTree.Element:
+ element = ElementTree.Element(tag)
+ if self._fields and path in self._fields:
+ field_repr, field_info = self._fields[path]
+ if self.repeat_field_info or field_repr not in self._parsed_fields:
+ field_attributes = self._extract_attributes(field_info)
+ for k, v in field_attributes.items():
+ element.set(k, v)
+ self._parsed_fields.add(field_repr)
+ return element
+
+ def _init_fields_info(self):
+ if self.include_field_info and self._fields is None:
+ self._fields = {}
+ self._parse_data_structures(self.data, fields_map=self._fields)
+
+ def _init_element_names(self):
+ if self._element_names is None:
+ self._element_names = {}
+ self._parse_data_structures(self.data, element_names=self._element_names)
+
+ def _mapping_to_xml(
+ self,
+ element: ElementTree.Element,
+ mapping: Mapping[Any, Any],
+ path: str = '',
+ ) -> None:
for key, value in mapping.items():
if isinstance(key, int):
key = str(key)
elif not isinstance(key, str):
raise TypeError(f'Unsupported key type for XML formatting: {type(key)}, only str and int are allowed')
- element.append(self.to_xml(value, key))
+ element.append(self._to_xml(value, key, f'{path}.{key}' if path else key))
+
+ @classmethod
+ def _parse_data_structures(
+ cls,
+ value: Any,
+ element_names: dict[str, str] | None = None,
+ fields_map: dict[str, tuple[str, FieldInfo | ComputedFieldInfo]] | None = None,
+ path: str = '',
+ ):
+ """Walk nested data such as dataclasses or Pydantic models, collecting element names and field info."""
+ if value is None or isinstance(value, (str, int, float, date, bytearray, bytes, bool)):
+ return
+ elif isinstance(value, Mapping):
+ for k, v in value.items(): # pyright: ignore[reportUnknownVariableType]
+ cls._parse_data_structures(v, element_names, fields_map, f'{path}.{k}' if path else f'{k}')
+ elif is_dataclass(value) and not isinstance(value, type):
+ if element_names is not None:
+ element_names[path] = value.__class__.__name__
+ for k, v in asdict(value).items():
+ cls._parse_data_structures(v, element_names, fields_map, f'{path}.{k}' if path else f'{k}')
+ elif isinstance(value, BaseModel):
+ if element_names is not None:
+ element_names[path] = value.__class__.__name__
+ for model_fields in (value.__class__.model_fields, value.__class__.model_computed_fields):
+ for field, info in model_fields.items():
+ new_path = f'{path}.{field}' if path else field
+ field_repr = f'{value.__class__.__name__}.{field}'
+ if (fields_map is not None) and (isinstance(info, ComputedFieldInfo) or not info.exclude):
+ fields_map[new_path] = (field_repr, info)
+ cls._parse_data_structures(getattr(value, field), element_names, fields_map, new_path)
+ elif isinstance(value, Iterable):
+ for n, item in enumerate(value): # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType]
+ new_path = f'{path}.[{n}]' if path else f'[{n}]'
+ cls._parse_data_structures(item, element_names, fields_map, new_path)
+
+ @classmethod
+ def _extract_attributes(cls, info: FieldInfo | ComputedFieldInfo) -> dict[str, str]:
+ attributes: dict[str, str] = {}
+ for attr in cls._FIELD_ATTRIBUTES:
+ attr_value = getattr(info, attr, None)
+ if attr_value is not None:
+ attributes[attr] = str(attr_value)
+ return attributes
def _rootless_xml_elements(root: ElementTree.Element, indent: str | None) -> Iterator[str]:
diff --git a/tests/test_format_as_xml.py b/tests/test_format_as_xml.py
index 37053a67f..edcb7ca48 100644
--- a/tests/test_format_as_xml.py
+++ b/tests/test_format_as_xml.py
@@ -1,10 +1,13 @@
+from __future__ import annotations as _annotations
+
from dataclasses import dataclass
from datetime import date, datetime
from typing import Any
import pytest
from inline_snapshot import snapshot
-from pydantic import BaseModel
+from pydantic import BaseModel, Field, computed_field
+from typing_extensions import Self
from pydantic_ai import format_as_xml
@@ -20,6 +23,19 @@ class ExamplePydanticModel(BaseModel):
age: int
+class ExamplePydanticFields(BaseModel):
+ name: str = Field(description="The person's name")
+ age: int = Field(description='Years', title='Age', default=18)
+ height: float = Field(description="The person's height", exclude=True)
+ children: list[Self] | None = Field(alias='child', default=None)
+
+ @computed_field(title='Location')
+ def location(self) -> str | None:
+ if self.name == 'John':
+ return 'Australia'
+ return None
+
+
@pytest.mark.parametrize(
'input_obj,output',
[
@@ -124,7 +140,366 @@ class ExamplePydanticModel(BaseModel):
],
)
def test_root_tag(input_obj: Any, output: str):
- assert format_as_xml(input_obj, root_tag='examples', item_tag='example') == output
+ assert format_as_xml(input_obj, root_tag='examples', item_tag='example', include_field_info=False) == output
+ assert format_as_xml(input_obj, root_tag='examples', item_tag='example', include_field_info=True) == output
+
+
+@pytest.mark.parametrize(
+ 'input_obj,use_fields,output',
+ [
+ pytest.param(
+ ExamplePydanticFields(
+ name='John',
+ age=42,
+ height=160.0,
+ child=[
+ ExamplePydanticFields(name='Liam', height=150),
+ ExamplePydanticFields(name='Alice', height=160),
+ ],
+ ),
+ True,
+ snapshot("""\
+John
+42
+
+
+ Liam
+ 18
+ null
+ null
+
+
+ Alice
+ 18
+ null
+ null
+
+
+Australia\
+"""),
+ id='pydantic model with fields',
+ ),
+ pytest.param(
+ [
+ ExamplePydanticFields(
+ name='John',
+ age=42,
+ height=160.0,
+ child=[
+ ExamplePydanticFields(name='Liam', height=150),
+ ExamplePydanticFields(name='Alice', height=160),
+ ],
+ )
+ ],
+ True,
+ snapshot("""\
+
+ John
+ 42
+
+
+ Liam
+ 18
+ null
+ null
+
+
+ Alice
+ 18
+ null
+ null
+
+
+ Australia
+\
+"""),
+ id='list[pydantic model with fields]',
+ ),
+ pytest.param(
+ ExamplePydanticFields(
+ name='John',
+ age=42,
+ height=160.0,
+ child=[
+ ExamplePydanticFields(name='Liam', height=150),
+ ExamplePydanticFields(name='Alice', height=160),
+ ],
+ ),
+ False,
+ snapshot("""\
+John
+42
+
+
+ Liam
+ 18
+ null
+ null
+
+
+ Alice
+ 18
+ null
+ null
+
+
+Australia\
+"""),
+ id='pydantic model without fields',
+ ),
+ ],
+)
+def test_fields(input_obj: Any, use_fields: bool, output: str):
+ assert format_as_xml(input_obj, include_field_info=use_fields) == output
+
+
+def test_repeated_field_attributes():
+ class DataItem(BaseModel):
+ user1: ExamplePydanticFields
+ user2: ExamplePydanticFields
+
+ data = ExamplePydanticFields(
+ name='John',
+ age=42,
+ height=160.0,
+ child=[
+ ExamplePydanticFields(name='Liam', height=150),
+ ExamplePydanticFields(name='Alice', height=160),
+ ],
+ )
+ assert (
+ format_as_xml(data, include_field_info=True, repeat_field_info=True)
+ == """\
+John
+42
+
+
+ Liam
+ 18
+ null
+ null
+
+
+ Alice
+ 18
+ null
+ null
+
+
+Australia\
+"""
+ )
+
+ assert (
+ format_as_xml(DataItem(user1=data, user2=data.model_copy()), include_field_info=True, repeat_field_info=True)
+ == """\
+
+ John
+ 42
+
+
+ Liam
+ 18
+ null
+ null
+
+
+ Alice
+ 18
+ null
+ null
+
+
+ Australia
+
+
+ John
+ 42
+
+
+ Liam
+ 18
+ null
+ null
+
+
+ Alice
+ 18
+ null
+ null
+
+
+ Australia
+\
+"""
+ )
+
+ assert (
+ format_as_xml(DataItem(user1=data, user2=data.model_copy()), include_field_info=True, repeat_field_info=False)
+ == """\
+
+ John
+ 42
+
+
+ Liam
+ 18
+ null
+ null
+
+
+ Alice
+ 18
+ null
+ null
+
+
+ Australia
+
+
+ John
+ 42
+
+
+ Liam
+ 18
+ null
+ null
+
+
+ Alice
+ 18
+ null
+ null
+
+
+ Australia
+\
+"""
+ )
+
+
+def test_nested_data():
+ @dataclass
+ class DataItem1:
+ id: str | None = None
+
+ class ModelItem1(BaseModel):
+ name: str = Field(description='Name')
+ value: int
+ items: list[DataItem1] = Field(description='Items')
+
+ @dataclass
+ class DataItem2:
+ model: ModelItem1
+ others: tuple[ModelItem1] | None = None
+ count: int = 10
+
+ data = {
+ 'values': [
+ DataItem2(
+ ModelItem1(name='Alice', value=42, items=[DataItem1('xyz')]),
+ (ModelItem1(name='Liam', value=3, items=[]),),
+ ),
+ DataItem2(
+ ModelItem1(
+ name='Bob',
+ value=7,
+ items=[
+ DataItem1('a'),
+ DataItem1(),
+ ],
+ ),
+ count=42,
+ ),
+ ]
+ }
+
+ assert (
+ format_as_xml(data, include_field_info=True)
+ == """
+
+
+
+ Alice
+ 42
+
+
+ xyz
+
+
+
+
+
+ Liam
+ 3
+
+
+
+ 10
+
+
+
+ Bob
+ 7
+
+
+ a
+
+
+ null
+
+
+
+ null
+ 42
+
+
+""".strip()
+ )
+
+ assert (
+ format_as_xml(data, include_field_info=False)
+ == """
+
+
+
+ Alice
+ 42
+
+
+ xyz
+
+
+
+
+
+ Liam
+ 3
+
+
+
+ 10
+
+
+
+ Bob
+ 7
+
+
+ a
+
+
+ null
+
+
+
+ null
+ 42
+
+
+""".strip()
+ )
@pytest.mark.parametrize(
@@ -194,6 +569,15 @@ def test_invalid_key():
format_as_xml({(1, 2): 42})
+def test_parse_invalid_value():
+ class Invalid(BaseModel):
+ name: str = Field(default='Alice', title='Name')
+ bad: Any = object()
+
+ with pytest.raises(TypeError, match='Unsupported type'):
+ format_as_xml(Invalid(), include_field_info=True)
+
+
def test_set():
assert '1' in format_as_xml({1, 2, 3}, item_tag='example')