Skip to content

add XML attributes when formatting Pydantic models in prompts #2313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 117 additions & 16 deletions pydantic_ai_slim/pydantic_ai/format_prompt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations as _annotations

from collections.abc import Iterable, Iterator, Mapping
from dataclasses import asdict, dataclass, is_dataclass
from dataclasses import asdict, dataclass, field, is_dataclass
from datetime import date
from typing import Any
from xml.etree import ElementTree
Expand All @@ -10,13 +10,17 @@

__all__ = ('format_as_xml',)

from pydantic.fields import ComputedFieldInfo, FieldInfo


def format_as_xml(
obj: Any,
root_tag: str | None = None,
item_tag: str = 'item',
none_str: str = 'null',
indent: str | None = ' ',
include_field_info: bool = False,
repeat_field_info: bool = False,
) -> str:
"""Format a Python object as XML.

Expand All @@ -33,6 +37,10 @@ def format_as_xml(
for dataclasses and Pydantic models.
none_str: String to use for `None` values.
indent: Indentation string to use for pretty printing.
include_field_info: Whether to include attributes like Pydantic Field attributes (title, description, alias)
as XML attributes.
repeat_field_info: Whether to include XML attributes extracted from a field info for each occurrence of an XML
element relative to the same field.

Returns:
XML representation of the object.
Expand All @@ -51,7 +59,13 @@ def format_as_xml(
'''
```
"""
el = _ToXml(item_tag=item_tag, none_str=none_str).to_xml(obj, root_tag)
el = _ToXml(
data=obj,
item_tag=item_tag,
none_str=none_str,
include_field_info=include_field_info,
repeat_field_info=repeat_field_info,
).to_xml(root_tag)
if root_tag is None and el.text is None:
join = '' if indent is None else '\n'
return join.join(_rootless_xml_elements(el, indent))
Expand All @@ -63,11 +77,24 @@ def format_as_xml(

@dataclass
class _ToXml:
data: Any
item_tag: str
none_str: str

def to_xml(self, value: Any, tag: str | None) -> ElementTree.Element:
element = ElementTree.Element(self.item_tag if tag is None else tag)
include_field_info: bool
repeat_field_info: bool
# a map of Pydantic Field paths to their metadata: a field unique string representation and its class
_fields: dict[str, tuple[str, FieldInfo | ComputedFieldInfo]] | None = None
# keep track of fields we have extracted attributes from
_parsed_fields: set[str] = field(default_factory=set)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is more like `included_fields`, right?

# keep track of class names for dataclasses and Pydantic models, that occur in lists
_element_names: dict[str, str] | None = None
_FIELD_ATTRIBUTES = ('title', 'description', 'alias')

def to_xml(self, tag: str | None) -> ElementTree.Element:
return self._to_xml(self.data, tag)

def _to_xml(self, value: Any, tag: str | None, path: str = '') -> ElementTree.Element:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If path should only be omitted for the root node, I think we should make it required and pass '' explicitly there

element = self._create_element(self.item_tag if tag is None else tag, path)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We create a new element in some cases below, can we change this to only build the element we're actually going to use?

if value is None:
element.text = self.none_str
elif isinstance(value, str):
Expand All @@ -79,31 +106,105 @@ def to_xml(self, value: Any, tag: str | None) -> ElementTree.Element:
elif isinstance(value, date):
element.text = value.isoformat()
elif isinstance(value, Mapping):
self._mapping_to_xml(element, value) # pyright: ignore[reportUnknownArgumentType]
if tag is None and self._element_names and path in self._element_names:
element = self._create_element(self._element_names[path], path)
self._mapping_to_xml(element, value, path) # pyright: ignore[reportUnknownArgumentType]
elif is_dataclass(value) and not isinstance(value, type):
self._init_element_names()
if tag is None:
element = ElementTree.Element(value.__class__.__name__)
dc_dict = asdict(value)
self._mapping_to_xml(element, dc_dict)
element = self._create_element(value.__class__.__name__, path)
self._mapping_to_xml(element, asdict(value), path)
elif isinstance(value, BaseModel):
# before serializing the model and losing all the metadata of other data structures contained in it,
# we extract all the fields info and class names
self._init_fields_info()
self._init_element_names()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These 2 calls end up calling _parse_data_structures twice, could we do it just once?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Combined with my suggestion to always initialize _fields and _element_names as empty dicts, I think we can call self._parse_data_structures(self.data) when we see a BaseModel or dataclass and handle which (or both) of the two to populate in there

if tag is None:
element = ElementTree.Element(value.__class__.__name__)
self._mapping_to_xml(element, value.model_dump(mode='python'))
element = self._create_element(value.__class__.__name__, path)
self._mapping_to_xml(element, value.model_dump(mode='python'), path)
elif isinstance(value, Iterable):
for item in value: # pyright: ignore[reportUnknownVariableType]
item_el = self.to_xml(item, None)
element.append(item_el)
for n, item in enumerate(value): # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType]
element.append(self._to_xml(item, None, f'{path}.[{n}]' if path else f'[{n}]'))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the `tag` argument of `_to_xml` can be None, can we make that its default value so we can skip passing None here?

else:
raise TypeError(f'Unsupported type for XML formatting: {type(value)}')
return element

def _mapping_to_xml(self, element: ElementTree.Element, mapping: Mapping[Any, Any]) -> None:
def _create_element(self, tag: str, path: str) -> ElementTree.Element:
element = ElementTree.Element(tag)
if self._fields and path in self._fields:
field_repr, field_info = self._fields[path]
if self.repeat_field_info or field_repr not in self._parsed_fields:
field_attributes = self._extract_attributes(field_info)
for k, v in field_attributes.items():
element.set(k, v)
self._parsed_fields.add(field_repr)
return element

def _init_fields_info(self):
if self.include_field_info and self._fields is None:
self._fields = {}
self._parse_data_structures(self.data, fields_map=self._fields)

def _init_element_names(self):
if self._element_names is None:
self._element_names = {}
self._parse_data_structures(self.data, element_names=self._element_names)

def _mapping_to_xml(
self,
element: ElementTree.Element,
mapping: Mapping[Any, Any],
path: str = '',
) -> None:
for key, value in mapping.items():
if isinstance(key, int):
key = str(key)
elif not isinstance(key, str):
raise TypeError(f'Unsupported key type for XML formatting: {type(key)}, only str and int are allowed')
element.append(self.to_xml(value, key))
element.append(self._to_xml(value, key, f'{path}.{key}' if path else key))

@classmethod
def _parse_data_structures(
cls,
value: Any,
element_names: dict[str, str] | None = None,
fields_map: dict[str, tuple[str, FieldInfo | ComputedFieldInfo]] | None = None,
path: str = '',
):
"""Parse data structures as dataclasses or Pydantic models to extract element names and attributes."""
if value is None or isinstance(value, (str, int, float, date, bytearray, bytes, bool)):
return
elif isinstance(value, Mapping):
for k, v in value.items(): # pyright: ignore[reportUnknownVariableType]
cls._parse_data_structures(v, element_names, fields_map, f'{path}.{k}' if path else f'{k}')
elif is_dataclass(value) and not isinstance(value, type):
if element_names is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we give self._element_names a default value of {} and always write directly into that instead of checking for None and passing element_names around as an arg?

Same for fields_map

element_names[path] = value.__class__.__name__
for k, v in asdict(value).items():
cls._parse_data_structures(v, element_names, fields_map, f'{path}.{k}' if path else f'{k}')
elif isinstance(value, BaseModel):
Comment on lines +184 to +185
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dataclass fields can also have descriptions, via field(metadata=) or Pydantic Field. See also https://docs.pydantic.dev/latest/concepts/dataclasses/. Any chance we can pull those out as well?

We may want to use TypeAdapter (as documented there) and use its JSON schema to get the values, as that already handles both dataclasses and BaseModels. Or, if we don't use it directly, we could look at how it does this and see whether we can use those same methods.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may not be the worst idea to use a TypeAdapter anyway, create JSON and JSON schema, and then use those to build the XML, so we don't have to handle dataclasses and BaseModels ourselves at all. That may be complicated with $refs and $defs though...

if element_names is not None:
element_names[path] = value.__class__.__name__
for model_fields in (value.__class__.model_fields, value.__class__.model_computed_fields):
for field, info in model_fields.items():
new_path = f'{path}.{field}' if path else field
field_repr = f'{value.__class__.__name__}.{field}'
if (fields_map is not None) and (isinstance(info, ComputedFieldInfo) or not info.exclude):
fields_map[new_path] = (field_repr, info)
cls._parse_data_structures(getattr(value, field), element_names, fields_map, new_path)
elif isinstance(value, Iterable):
for n, item in enumerate(value): # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType]
new_path = f'{path}.[{n}]' if path else f'[{n}]'
cls._parse_data_structures(item, element_names, fields_map, new_path)

@classmethod
def _extract_attributes(cls, info: FieldInfo | ComputedFieldInfo) -> dict[str, str]:
attributes: dict[str, str] = {}
for attr in cls._FIELD_ATTRIBUTES:
attr_value = getattr(info, attr, None)
if attr_value is not None:
attributes[attr] = str(attr_value)
return attributes
Comment on lines +203 to +207
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be a one-liner, so we may not need a method:

Suggested change
for attr in cls._FIELD_ATTRIBUTES:
attr_value = getattr(info, attr, None)
if attr_value is not None:
attributes[attr] = str(attr_value)
return attributes
return {
attr: str(value)
for attr in cls._FIELD_ATTRIBUTES
if (value := getattr(info, attr, None)) is not None
}



def _rootless_xml_elements(root: ElementTree.Element, indent: str | None) -> Iterator[str]:
Expand Down
Loading