Skip to content

Commit 2fd35c8

Browse files
committed
fix bug
1 parent 677e145 commit 2fd35c8

File tree

2 files changed

+19
-14
lines changed

2 files changed

+19
-14
lines changed

src/snowflake/snowpark/_internal/xml_reader.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -869,13 +869,12 @@ def infer_schema(
869869
# Attributes (same rule as element_to_dict_or_str)
870870
if not exclude_attributes:
871871
for attr_name, attr_value in element.attrib.items():
872-
fields.append(
873-
StructField(
874-
f"{attribute_prefix}{attr_name}",
875-
infer_type(attr_value, ignore_surrounding_whitespace, null_value),
876-
True,
877-
)
872+
field = StructField(
873+
f"'{attribute_prefix}{attr_name}'",
874+
infer_type(attr_value, ignore_surrounding_whitespace, null_value),
875+
True,
878876
)
877+
fields.append(field)
879878

880879
# Children
881880
if children:
@@ -899,19 +898,19 @@ def infer_schema(
899898
assert dt is not None
900899
if len(elems) > 1:
901900
dt = ArrayType(dt)
902-
fields.append(StructField(tag, dt, True))
901+
field = StructField(f"'{tag}'", dt, True)
902+
fields.append(field)
903903
else:
904904
# No children, but has attributes -> also include the value_tag for text if present and not null
905905
# (matches element_to_dict_or_str behavior)
906906
t = norm_text(ignore_surrounding_whitespace, element.text)
907907
if t is not None and t != null_value:
908-
fields.append(
909-
StructField(
910-
value_tag,
911-
infer_type(t, ignore_surrounding_whitespace, null_value),
912-
True,
913-
)
908+
field = StructField(
909+
f"'{value_tag}'",
910+
infer_type(t, ignore_surrounding_whitespace, null_value),
911+
True,
914912
)
913+
fields.append(field)
915914

916915
return StructType(fields)
917916

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1076,14 +1076,20 @@ def xml(self, path: str, _emit_ast: bool = True) -> DataFrame:
10761076

10771077
# cast to input custom schema type
10781078
# TODO: SNOW-2923003: remove single quote after server side BCR is done
1079-
if self._user_schema or self._infer_schema:
1079+
if self._user_schema and not self._infer_schema:
10801080
cols = [
10811081
df[single_quote(field._name)]
10821082
.cast(field.datatype)
10831083
.alias(quote_name_without_upper_casing(field._name))
10841084
for field in self._user_schema.fields
10851085
]
10861086
return df.select(cols)
1087+
elif self._infer_schema:
1088+
cols = [
1089+
df[field._name].cast(field.datatype).alias(field._name)
1090+
for field in self._user_schema.fields
1091+
]
1092+
return df.select(cols)
10871093
else:
10881094
return df
10891095

0 commit comments

Comments
 (0)