Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .chronus/changes/external-type-python-2025-9-20-14-28-41.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
changeKind: feature
packages:
- "@typespec/http-client-python"
---

Support SDK users defined customized serialization/deserialization function for external models
6 changes: 6 additions & 0 deletions packages/http-client-python/emitter/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,12 @@ function emitModel(context: PythonSdkContext, type: SdkModelType): Record<string
submodule: "exceptions",
};
}
if (type.external) {
return getSimpleTypeResult({
type: "external",
externalTypeInfo: type.external,
});
}
const parents: Record<string, any>[] = [];
const newValue = {
type: type.kind,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
SdkCoreType,
DecimalType,
MultiPartFileType,
ExternalType,
)
from .enum_type import EnumType, EnumValue
from .base import BaseType
Expand Down Expand Up @@ -151,6 +152,7 @@
"credential": StringType,
"sdkcore": SdkCoreType,
"multipartfile": MultiPartFileType,
"external": ExternalType,
}
_LOGGER = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .enum_type import EnumType
from .model_type import ModelType, UsageFlags
from .combined_type import CombinedType
from .primitive_types import ExternalType
from .client import Client
from .request_builder import RequestBuilder, OverloadedRequestBuilder
from .operation_group import OperationGroup
Expand Down Expand Up @@ -101,6 +102,7 @@ def __init__(
self._operations_folder_name: dict[str, str] = {}
self._relative_import_path: dict[str, str] = {}
self.metadata: dict[str, Any] = yaml_data.get("metadata", {})
self.has_external_type = any(isinstance(t, ExternalType) for t in self.types_map.values())

@staticmethod
def get_imported_namespace_for_client(imported_namespace: str, async_mode: bool = False) -> str:
Expand Down Expand Up @@ -488,3 +490,8 @@ def _get_relative_generation_dir(self, root_dir: Path, namespace: str) -> Path:
@property
def has_operation_named_list(self) -> bool:
return any(o.name.lower() == "list" for c in self.clients for og in c.operation_groups for o in og.operations)

@property
def external_types(self) -> list[ExternalType]:
"""All of the external types"""
return [t for t in self.types_map.values() if isinstance(t, ExternalType)]
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,39 @@ def serialization_type(self, **kwargs: Any) -> str:
return self.name


class ExternalType(PrimitiveType):
def __init__(self, yaml_data: dict[str, Any], code_model: "CodeModel") -> None:
super().__init__(yaml_data=yaml_data, code_model=code_model)
self.external_type_info = yaml_data.get("externalTypeInfo", {})
self.identity = self.external_type_info.get("identity", "")
self.submodule = ".".join(self.identity.split(".")[:-1])
self.min_version = self.external_type_info.get("minVersion", "")
self.package_name = self.external_type_info.get("package", "")

def docstring_type(self, **kwargs: Any) -> str:
return f"~{self.identity}"

def type_annotation(self, **kwargs: Any) -> str:
return self.identity

def imports(self, **kwargs: Any) -> FileImport:
file_import = super().imports(**kwargs)
file_import.add_import(self.submodule, ImportType.THIRDPARTY, TypingSection.REGULAR)
return file_import

@property
def instance_check_template(self) -> str:
return f"isinstance({{}}, {self.identity})"

def serialization_type(self, **kwargs: Any) -> str:
return self.identity

@property
def default_template_representation_declaration(self) -> str:
value = f"{self.identity}(...)"
return f'"{value}"' if self.code_model.for_test else value


class MultiPartFileType(PrimitiveType):
def __init__(self, yaml_data: dict[str, Any], code_model: "CodeModel") -> None:
super().__init__(yaml_data=yaml_data, code_model=code_model)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"msrest": "0.7.1",
"isodate": "0.6.1",
"azure-mgmt-core": "1.6.0",
"azure-core": "1.35.0",
"azure-core": "1.36.0",
"typing-extensions": "4.6.0",
"corehttp": "1.0.0b6",
}
Expand Down Expand Up @@ -57,7 +57,16 @@ def _extract_min_dependency(self, s):
m = re.search(r"[>=]=?([\d.]+(?:[a-z]+\d+)?)", s)
return parse_version(m.group(1)) if m else parse_version("0")

def _keep_pyproject_fields(self, file_content: str) -> dict:
def _update_version_map(self, version_map: dict[str, str], dep_name: str, dep: str) -> None:
# For tracked dependencies, check if the version is higher than our default
default_version = parse_version(version_map[dep_name])
dep_version = self._extract_min_dependency(dep)
# If the version is higher than the default, update VERSION_MAP
# with higher min dependency version
if dep_version > default_version:
version_map[dep_name] = str(dep_version)

def _keep_pyproject_fields(self, file_content: str, additional_version_map: dict[str, str]) -> dict:
# Load the pyproject.toml file if it exists and extract fields to keep.
result: dict = {"KEEP_FIELDS": {}}
try:
Expand All @@ -80,15 +89,11 @@ def _keep_pyproject_fields(self, file_content: str) -> dict:
for dep in loaded_pyproject_toml["project"]["dependencies"]:
dep_name = re.split(r"[<>=\[]", dep)[0].strip()

# Check if dependency is one we track in VERSION_MAP
# Check if dependency is one we track in version map
if dep_name in VERSION_MAP:
# For tracked dependencies, check if the version is higher than our default
default_version = parse_version(VERSION_MAP[dep_name])
dep_version = self._extract_min_dependency(dep)
# If the version is higher than the default, update VERSION_MAP
# with higher min dependency version
if dep_version > default_version:
VERSION_MAP[dep_name] = str(dep_version)
self._update_version_map(VERSION_MAP, dep_name, dep)
elif dep_name in additional_version_map:
self._update_version_map(additional_version_map, dep_name, dep)
else:
# Keep non-default dependencies
kept_deps.append(dep)
Expand All @@ -107,9 +112,18 @@ def _keep_pyproject_fields(self, file_content: str) -> dict:
def serialize_package_file(self, template_name: str, file_content: str, **kwargs: Any) -> str:
template = self.env.get_template(template_name)

additional_version_map = {}
if self.code_model.has_external_type:
for item in self.code_model.external_types:
if item.package_name:
if item.min_version:
additional_version_map[item.package_name] = item.min_version
else:
additional_version_map[item.package_name] = "0"

# Add fields to keep from an existing pyproject.toml
if template_name == "pyproject.toml.jinja2":
params = self._keep_pyproject_fields(file_content)
params = self._keep_pyproject_fields(file_content, additional_version_map)
else:
params = {}

Expand All @@ -126,6 +140,7 @@ def serialize_package_file(self, template_name: str, file_content: str, **kwargs
dev_status = "4 - Beta"
else:
dev_status = "5 - Production/Stable"

params |= {
"code_model": self.code_model,
"dev_status": dev_status,
Expand All @@ -136,6 +151,7 @@ def serialize_package_file(self, template_name: str, file_content: str, **kwargs
"VERSION_MAP": VERSION_MAP,
"MIN_PYTHON_VERSION": MIN_PYTHON_VERSION,
"MAX_PYTHON_VERSION": MAX_PYTHON_VERSION,
"ADDITIONAL_DEPENDENCIES": [f"{item[0]}>={item[1]}" for item in additional_version_map.items()],
}
params |= {"options": self.code_model.options}
params |= kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ from {{ code_model.core_library }}.exceptions import DeserializationError
from {{ code_model.core_library }}{{ "" if code_model.is_azure_flavor else ".utils" }} import CaseInsensitiveEnumMeta
from {{ code_model.core_library }}.{{ "" if code_model.is_azure_flavor else "runtime." }}pipeline import PipelineResponse
from {{ code_model.core_library }}.serialization import _Null
{% if code_model.has_external_type %}
from {{ code_model.core_library }}.serialization import TypeHandlerRegistry
{% endif %}
from {{ code_model.core_library }}.rest import HttpResponse

_LOGGER = logging.getLogger(__name__)
Expand All @@ -34,6 +37,10 @@ __all__ = ["SdkJSONEncoder", "Model", "rest_field", "rest_discriminator"]
TZ_UTC = timezone.utc
_T = typing.TypeVar("_T")

{% if code_model.has_external_type %}
TYPE_HANDLER_REGISTRY = TypeHandlerRegistry()
{% endif %}


def _timedelta_as_isostr(td: timedelta) -> str:
"""Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S'
Expand Down Expand Up @@ -158,6 +165,11 @@ class SdkJSONEncoder(JSONEncoder):
except AttributeError:
# This will be raised when it hits value.total_seconds in the method above
pass
{% if code_model.has_external_type %}
custom_serializer = TYPE_HANDLER_REGISTRY.get_serializer(o)
if custom_serializer:
return custom_serializer(o)
{% endif %}
return super(SdkJSONEncoder, self).default(o)


Expand Down Expand Up @@ -313,7 +325,13 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] =
return _deserialize_int_as_str
if rf and rf._format:
return _DESERIALIZE_MAPPING_WITHFORMAT.get(rf._format)
{% if code_model.has_external_type %}
if _DESERIALIZE_MAPPING.get(annotation):
return _DESERIALIZE_MAPPING.get(annotation) # pyright: ignore
return TYPE_HANDLER_REGISTRY.get_deserializer(annotation) # pyright: ignore
{% else %}
return _DESERIALIZE_MAPPING.get(annotation) # pyright: ignore
{% endif %}


def _get_type_alias_type(module_name: str, alias_name: str):
Expand Down Expand Up @@ -507,6 +525,14 @@ def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-m
except AttributeError:
# This will be raised when it hits value.total_seconds in the method above
pass

{% if code_model.has_external_type %}
# Check if there's a custom serializer for the type
custom_serializer = TYPE_HANDLER_REGISTRY.get_serializer(o)
if custom_serializer:
return custom_serializer(o)
{% endif %}

return o


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ dependencies = [
"{{ dep }}",
{% endfor %}
{% endif %}
{% for dep in ADDITIONAL_DEPENDENCIES %}
"{{ dep }}",
{% endfor %}
]
dynamic = [
{% if options.get('package-mode') %}"version", {% endif %}"readme"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ setup(
"corehttp[requests]>={{ VERSION_MAP["corehttp"] }}",
{% endif %}
"typing-extensions>={{ VERSION_MAP['typing-extensions'] }}",
{% for dep in ADDITIONAL_DEPENDENCIES %}
{{ dep }},
{% endfor %}
],
{% if options["package-mode"] %}
python_requires=">={{ MIN_PYTHON_VERSION }}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ azure-mgmt-core==1.6.0
-e ./generated/azure-client-generator-core-usage
-e ./generated/azure-client-generator-core-override
-e ./generated/azure-client-generator-core-client-location
-e ./generated/azure-client-generator-core-alternate-type
-e ./generated/azure-core-basic
-e ./generated/azure-core-scalar
-e ./generated/azure-core-lro-rpc
Expand Down
Loading