-
Notifications
You must be signed in to change notification settings - Fork 1.1k
add XML attributes when formatting Pydantic models in prompts #2313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a9b5e5f
c560f3c
bab5922
f99ff5d
3f2c5dc
051aa93
490afa4
7e1ce2f
ab47813
ad43c00
f6b0cb8
e1cbf5f
9e6f376
d9a73c8
f223496
ba5c034
595234f
07d737c
1d3473a
01d3ffd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,7 +1,7 @@ | ||||||||||||||||||||||
from __future__ import annotations as _annotations | ||||||||||||||||||||||
|
||||||||||||||||||||||
from collections.abc import Iterable, Iterator, Mapping | ||||||||||||||||||||||
from dataclasses import asdict, dataclass, is_dataclass | ||||||||||||||||||||||
from dataclasses import asdict, dataclass, field, is_dataclass | ||||||||||||||||||||||
from datetime import date | ||||||||||||||||||||||
from typing import Any | ||||||||||||||||||||||
from xml.etree import ElementTree | ||||||||||||||||||||||
|
@@ -10,13 +10,17 @@ | |||||||||||||||||||||
|
||||||||||||||||||||||
__all__ = ('format_as_xml',) | ||||||||||||||||||||||
|
||||||||||||||||||||||
from pydantic.fields import ComputedFieldInfo, FieldInfo | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
def format_as_xml( | ||||||||||||||||||||||
obj: Any, | ||||||||||||||||||||||
root_tag: str | None = None, | ||||||||||||||||||||||
item_tag: str = 'item', | ||||||||||||||||||||||
none_str: str = 'null', | ||||||||||||||||||||||
indent: str | None = ' ', | ||||||||||||||||||||||
include_field_info: bool = False, | ||||||||||||||||||||||
repeat_field_info: bool = False, | ||||||||||||||||||||||
) -> str: | ||||||||||||||||||||||
"""Format a Python object as XML. | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
@@ -33,6 +37,10 @@ def format_as_xml( | |||||||||||||||||||||
for dataclasses and Pydantic models. | ||||||||||||||||||||||
none_str: String to use for `None` values. | ||||||||||||||||||||||
indent: Indentation string to use for pretty printing. | ||||||||||||||||||||||
include_field_info: Whether to include attributes like Pydantic Field attributes (title, description, alias) | ||||||||||||||||||||||
as XML attributes. | ||||||||||||||||||||||
repeat_field_info: Whether to include XML attributes extracted from a field info for each occurrence of an XML | ||||||||||||||||||||||
element relative to the same field. | ||||||||||||||||||||||
|
||||||||||||||||||||||
Returns: | ||||||||||||||||||||||
XML representation of the object. | ||||||||||||||||||||||
|
@@ -51,7 +59,13 @@ def format_as_xml( | |||||||||||||||||||||
''' | ||||||||||||||||||||||
``` | ||||||||||||||||||||||
""" | ||||||||||||||||||||||
el = _ToXml(item_tag=item_tag, none_str=none_str).to_xml(obj, root_tag) | ||||||||||||||||||||||
el = _ToXml( | ||||||||||||||||||||||
data=obj, | ||||||||||||||||||||||
item_tag=item_tag, | ||||||||||||||||||||||
none_str=none_str, | ||||||||||||||||||||||
include_field_info=include_field_info, | ||||||||||||||||||||||
repeat_field_info=repeat_field_info, | ||||||||||||||||||||||
).to_xml(root_tag) | ||||||||||||||||||||||
if root_tag is None and el.text is None: | ||||||||||||||||||||||
join = '' if indent is None else '\n' | ||||||||||||||||||||||
return join.join(_rootless_xml_elements(el, indent)) | ||||||||||||||||||||||
|
@@ -63,11 +77,24 @@ def format_as_xml( | |||||||||||||||||||||
|
||||||||||||||||||||||
@dataclass | ||||||||||||||||||||||
class _ToXml: | ||||||||||||||||||||||
data: Any | ||||||||||||||||||||||
item_tag: str | ||||||||||||||||||||||
none_str: str | ||||||||||||||||||||||
|
||||||||||||||||||||||
def to_xml(self, value: Any, tag: str | None) -> ElementTree.Element: | ||||||||||||||||||||||
element = ElementTree.Element(self.item_tag if tag is None else tag) | ||||||||||||||||||||||
include_field_info: bool | ||||||||||||||||||||||
repeat_field_info: bool | ||||||||||||||||||||||
# a map of Pydantic Field paths to their metadata: a field unique string representation and its class | ||||||||||||||||||||||
_fields: dict[str, tuple[str, FieldInfo | ComputedFieldInfo]] | None = None | ||||||||||||||||||||||
# keep track of fields we have extracted attributes from | ||||||||||||||||||||||
_parsed_fields: set[str] = field(default_factory=set) | ||||||||||||||||||||||
# keep track of class names for dataclasses and Pydantic models, that occur in lists | ||||||||||||||||||||||
_element_names: dict[str, str] | None = None | ||||||||||||||||||||||
_FIELD_ATTRIBUTES = ('title', 'description', 'alias') | ||||||||||||||||||||||
|
||||||||||||||||||||||
def to_xml(self, tag: str | None) -> ElementTree.Element: | ||||||||||||||||||||||
return self._to_xml(self.data, tag) | ||||||||||||||||||||||
|
||||||||||||||||||||||
def _to_xml(self, value: Any, tag: str | None, path: str = '') -> ElementTree.Element: | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If |
||||||||||||||||||||||
element = self._create_element(self.item_tag if tag is None else tag, path) | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We create a new |
||||||||||||||||||||||
if value is None: | ||||||||||||||||||||||
element.text = self.none_str | ||||||||||||||||||||||
elif isinstance(value, str): | ||||||||||||||||||||||
|
@@ -79,31 +106,105 @@ def to_xml(self, value: Any, tag: str | None) -> ElementTree.Element: | |||||||||||||||||||||
elif isinstance(value, date): | ||||||||||||||||||||||
element.text = value.isoformat() | ||||||||||||||||||||||
elif isinstance(value, Mapping): | ||||||||||||||||||||||
self._mapping_to_xml(element, value) # pyright: ignore[reportUnknownArgumentType] | ||||||||||||||||||||||
if tag is None and self._element_names and path in self._element_names: | ||||||||||||||||||||||
element = self._create_element(self._element_names[path], path) | ||||||||||||||||||||||
self._mapping_to_xml(element, value, path) # pyright: ignore[reportUnknownArgumentType] | ||||||||||||||||||||||
elif is_dataclass(value) and not isinstance(value, type): | ||||||||||||||||||||||
self._init_element_names() | ||||||||||||||||||||||
if tag is None: | ||||||||||||||||||||||
element = ElementTree.Element(value.__class__.__name__) | ||||||||||||||||||||||
dc_dict = asdict(value) | ||||||||||||||||||||||
self._mapping_to_xml(element, dc_dict) | ||||||||||||||||||||||
element = self._create_element(value.__class__.__name__, path) | ||||||||||||||||||||||
self._mapping_to_xml(element, asdict(value), path) | ||||||||||||||||||||||
elif isinstance(value, BaseModel): | ||||||||||||||||||||||
# before serializing the model and losing all the metadata of other data structures contained in it, | ||||||||||||||||||||||
# we extract all the fields info and class names | ||||||||||||||||||||||
self._init_fields_info() | ||||||||||||||||||||||
self._init_element_names() | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These 2 calls end up calling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Combined with my suggestion to always initialize |
||||||||||||||||||||||
if tag is None: | ||||||||||||||||||||||
element = ElementTree.Element(value.__class__.__name__) | ||||||||||||||||||||||
self._mapping_to_xml(element, value.model_dump(mode='python')) | ||||||||||||||||||||||
element = self._create_element(value.__class__.__name__, path) | ||||||||||||||||||||||
self._mapping_to_xml(element, value.model_dump(mode='python'), path) | ||||||||||||||||||||||
elif isinstance(value, Iterable): | ||||||||||||||||||||||
for item in value: # pyright: ignore[reportUnknownVariableType] | ||||||||||||||||||||||
item_el = self.to_xml(item, None) | ||||||||||||||||||||||
element.append(item_el) | ||||||||||||||||||||||
for n, item in enumerate(value): # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType] | ||||||||||||||||||||||
element.append(self._to_xml(item, None, f'{path}.[{n}]' if path else f'[{n}]')) | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since |
||||||||||||||||||||||
else: | ||||||||||||||||||||||
raise TypeError(f'Unsupported type for XML formatting: {type(value)}') | ||||||||||||||||||||||
return element | ||||||||||||||||||||||
|
||||||||||||||||||||||
def _mapping_to_xml(self, element: ElementTree.Element, mapping: Mapping[Any, Any]) -> None: | ||||||||||||||||||||||
def _create_element(self, tag: str, path: str) -> ElementTree.Element: | ||||||||||||||||||||||
element = ElementTree.Element(tag) | ||||||||||||||||||||||
if self._fields and path in self._fields: | ||||||||||||||||||||||
field_repr, field_info = self._fields[path] | ||||||||||||||||||||||
if self.repeat_field_info or field_repr not in self._parsed_fields: | ||||||||||||||||||||||
field_attributes = self._extract_attributes(field_info) | ||||||||||||||||||||||
for k, v in field_attributes.items(): | ||||||||||||||||||||||
element.set(k, v) | ||||||||||||||||||||||
self._parsed_fields.add(field_repr) | ||||||||||||||||||||||
return element | ||||||||||||||||||||||
|
||||||||||||||||||||||
def _init_fields_info(self): | ||||||||||||||||||||||
if self.include_field_info and self._fields is None: | ||||||||||||||||||||||
self._fields = {} | ||||||||||||||||||||||
self._parse_data_structures(self.data, fields_map=self._fields) | ||||||||||||||||||||||
|
||||||||||||||||||||||
def _init_element_names(self): | ||||||||||||||||||||||
if self._element_names is None: | ||||||||||||||||||||||
self._element_names = {} | ||||||||||||||||||||||
self._parse_data_structures(self.data, element_names=self._element_names) | ||||||||||||||||||||||
|
||||||||||||||||||||||
def _mapping_to_xml( | ||||||||||||||||||||||
self, | ||||||||||||||||||||||
element: ElementTree.Element, | ||||||||||||||||||||||
mapping: Mapping[Any, Any], | ||||||||||||||||||||||
path: str = '', | ||||||||||||||||||||||
) -> None: | ||||||||||||||||||||||
for key, value in mapping.items(): | ||||||||||||||||||||||
if isinstance(key, int): | ||||||||||||||||||||||
key = str(key) | ||||||||||||||||||||||
elif not isinstance(key, str): | ||||||||||||||||||||||
raise TypeError(f'Unsupported key type for XML formatting: {type(key)}, only str and int are allowed') | ||||||||||||||||||||||
element.append(self.to_xml(value, key)) | ||||||||||||||||||||||
element.append(self._to_xml(value, key, f'{path}.{key}' if path else key)) | ||||||||||||||||||||||
|
||||||||||||||||||||||
@classmethod | ||||||||||||||||||||||
def _parse_data_structures( | ||||||||||||||||||||||
cls, | ||||||||||||||||||||||
value: Any, | ||||||||||||||||||||||
element_names: dict[str, str] | None = None, | ||||||||||||||||||||||
fields_map: dict[str, tuple[str, FieldInfo | ComputedFieldInfo]] | None = None, | ||||||||||||||||||||||
path: str = '', | ||||||||||||||||||||||
DouweM marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||
): | ||||||||||||||||||||||
"""Parse data structures as dataclasses or Pydantic models to extract element names and attributes.""" | ||||||||||||||||||||||
if value is None or isinstance(value, (str, int, float, date, bytearray, bytes, bool)): | ||||||||||||||||||||||
return | ||||||||||||||||||||||
elif isinstance(value, Mapping): | ||||||||||||||||||||||
for k, v in value.items(): # pyright: ignore[reportUnknownVariableType] | ||||||||||||||||||||||
cls._parse_data_structures(v, element_names, fields_map, f'{path}.{k}' if path else f'{k}') | ||||||||||||||||||||||
elif is_dataclass(value) and not isinstance(value, type): | ||||||||||||||||||||||
if element_names is not None: | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we give Same for |
||||||||||||||||||||||
element_names[path] = value.__class__.__name__ | ||||||||||||||||||||||
for k, v in asdict(value).items(): | ||||||||||||||||||||||
cls._parse_data_structures(v, element_names, fields_map, f'{path}.{k}' if path else f'{k}') | ||||||||||||||||||||||
elif isinstance(value, BaseModel): | ||||||||||||||||||||||
Comment on lines
+184
to
+185
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Dataclass fields can also have descriptions, via We may want to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It may not be the worst idea to use a TypeAdapter anyway, create JSON and JSON schema, and then use those to build the XML, so we don't have to handle dataclasses and BaseModels ourselves at all. That may be complicated with $refs and $defs though... |
||||||||||||||||||||||
if element_names is not None: | ||||||||||||||||||||||
element_names[path] = value.__class__.__name__ | ||||||||||||||||||||||
for model_fields in (value.__class__.model_fields, value.__class__.model_computed_fields): | ||||||||||||||||||||||
for field, info in model_fields.items(): | ||||||||||||||||||||||
new_path = f'{path}.{field}' if path else field | ||||||||||||||||||||||
field_repr = f'{value.__class__.__name__}.{field}' | ||||||||||||||||||||||
if (fields_map is not None) and (isinstance(info, ComputedFieldInfo) or not info.exclude): | ||||||||||||||||||||||
fields_map[new_path] = (field_repr, info) | ||||||||||||||||||||||
cls._parse_data_structures(getattr(value, field), element_names, fields_map, new_path) | ||||||||||||||||||||||
elif isinstance(value, Iterable): | ||||||||||||||||||||||
for n, item in enumerate(value): # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType] | ||||||||||||||||||||||
new_path = f'{path}.[{n}]' if path else f'[{n}]' | ||||||||||||||||||||||
cls._parse_data_structures(item, element_names, fields_map, new_path) | ||||||||||||||||||||||
|
||||||||||||||||||||||
@classmethod | ||||||||||||||||||||||
def _extract_attributes(cls, info: FieldInfo | ComputedFieldInfo) -> dict[str, str]: | ||||||||||||||||||||||
attributes: dict[str, str] = {} | ||||||||||||||||||||||
for attr in cls._FIELD_ATTRIBUTES: | ||||||||||||||||||||||
attr_value = getattr(info, attr, None) | ||||||||||||||||||||||
if attr_value is not None: | ||||||||||||||||||||||
attributes[attr] = str(attr_value) | ||||||||||||||||||||||
return attributes | ||||||||||||||||||||||
Comment on lines
+203
to
+207
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can be a oneliner, so we may not need a method:
Suggested change
|
||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
def _rootless_xml_elements(root: ElementTree.Element, indent: str | None) -> Iterator[str]: | ||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This more like
included_fields
right?