Skip to content

Commit 0811f0a

Browse files
committed
Add a codemod for dataclass changes in 3.11
Python 3.11 changed the field default mutability check for dataclasses to only allow defaults which are hashable. This codemod helps with the migration by changing all default values that aren't obviously hashable to use `default_factory` instead. Note: it's impossible to accurately determine if a particular expression produces a hashable value in a codemod, so the codemod significantly over-approximates what's unhashable.
1 parent d24192a commit 0811f0a

File tree

2 files changed

+209
-0
lines changed

2 files changed

+209
-0
lines changed
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
# pyre-strict
6+
import builtins
7+
8+
import libcst as cst
9+
from libcst import matchers as m
10+
from libcst.codemod import VisitorBasedCodemodCommand
11+
from libcst.codemod.visitors._add_imports import AddImportsVisitor
12+
from libcst.metadata.name_provider import QualifiedNameProvider
13+
from libcst.helpers import ensure_type
14+
15+
16+
def node_with_qname(expected: str) -> m.MatchMetadataIfTrue:
    """
    Build a metadata matcher that accepts a node when any of the qualified
    names reported for it by ``QualifiedNameProvider`` has a ``name`` equal
    to ``expected`` (for example ``"dataclasses.dataclass"``).
    """
    return m.MatchMetadataIfTrue(
        QualifiedNameProvider,
        lambda candidates: expected in (candidate.name for candidate in candidates),
    )
21+
22+
23+
# Matches a statement line that contains at least one annotated assignment,
# e.g. ``x: int = 1``. Sequence matchers are anchored to the whole sequence,
# so the ZeroOrMore() wildcards on either side are what allow other small
# statements to share the line (``x: list = []; y = 1``).
annotation = m.SimpleStatementLine(
    body=[m.ZeroOrMore(), m.AnnAssign(), m.ZeroOrMore()]
)
# Matches any node that resolves to ``dataclasses.dataclass``.
dataclass_ref = node_with_qname("dataclasses.dataclass")
# Builtin names whose first character is left unchanged by ``str.capitalize``:
# names starting with an uppercase letter (``True``, ``None``, the exception
# classes, ...) as well as names starting with a non-letter such as
# ``__debug__``. References to these names are treated as immutable defaults.
constant_expressions = [
    name for name in dir(builtins) if name.capitalize()[0] == name[0]
]
# Number and string literals; all of these are treated as immutable defaults.
literal = (
    m.Integer()
    | m.Float()
    | m.Imaginary()
    | m.SimpleString()
    | m.ConcatenatedString()
    | m.FormattedString()
)

# Matches a call whose callee resolves to ``dataclasses.field``.
field_call = m.Call(func=node_with_qname("dataclasses.field"))

# Matches a call that passes ``default=<expr>`` anywhere in its argument list
# and captures ``<expr>`` under the key "default_value". The surrounding
# ZeroOrMore() wildcards are required so that calls with additional arguments,
# such as ``field(default=[], repr=False)``, still match; a lone AtLeastN
# would only match when *every* argument is ``default=``.
default_arg = m.Call(
    args=[
        m.ZeroOrMore(),
        m.Arg(
            keyword=m.Name("default"),
            value=m.SaveMatchedNode(m.DoNotCare(), "default_value"),
        ),
        m.ZeroOrMore(),
    ],
)
50+
51+
52+
def wrap_in_default_factory(expr: cst.BaseExpression) -> cst.Arg:
    """
    Build the keyword argument ``default_factory=lambda: <expr>``.

    ``expr`` becomes the body of a zero-parameter lambda, so it is evaluated
    each time the factory is called rather than once at class definition.
    """
    factory = cst.Lambda(params=cst.Parameters(), body=expr)
    keyword = cst.Name("default_factory")
    return cst.Arg(value=factory, keyword=keyword)
57+
58+
59+
class DataclassDefaultFactoryCodemod(VisitorBasedCodemodCommand):
    """
    Converts dataclass fields with mutable default values to use default_factory.

    For example:
        @dataclass
        class Foo:
            x: list = []  # Mutable default, bad practice

    Becomes:
        @dataclass
        class Foo:
            x: list = field(default_factory=lambda: [])  # Better practice

    Only number/string literals and references to the names listed in
    ``constant_expressions`` are considered immutable; every other default
    expression is wrapped, so the rewrite deliberately over-approximates.

    NOTE(review): annotations are not inspected, so ``ClassVar[...]`` members
    with a non-literal default are rewritten as well. ``dataclasses`` rejects
    ``default_factory`` on ClassVar pseudo-fields -- consider skipping them.
    """

    METADATA_DEPENDENCIES = (QualifiedNameProvider,)

    def is_immutable(self, expr: cst.BaseExpression) -> bool:
        """
        Return True when ``expr`` is a number/string literal or a bare name
        listed in ``constant_expressions``.
        """
        return self.matches(
            expr,
            m.OneOf(
                literal,
                *[m.Name(name) for name in constant_expressions],
            ),
        )

    # Sequence matchers are anchored to the whole sequence, so ZeroOrMore()
    # wildcards are needed on both sides: without them only classes decorated
    # solely with @dataclass and whose body consists purely of annotated
    # assignments (no docstring, no methods) would ever be visited.
    @m.leave(
        m.ClassDef(
            decorators=[
                m.ZeroOrMore(),
                m.Decorator(dataclass_ref),
                m.ZeroOrMore(),
            ],
            body=m.IndentedBlock(
                body=[m.ZeroOrMore(), annotation, m.ZeroOrMore()]
            ),
        )
    )
    def handle_class(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """
        Rewrite the annotated assignments in a dataclass body.

        Annotation lines are rebuilt from ``original_node`` because the
        qualified-name metadata used by the matchers is attached to original
        nodes. All other statements are taken from ``updated_node`` so that
        rewrites already applied to nested classes are not thrown away.
        """
        original_lines = ensure_type(original_node.body, cst.IndentedBlock).body
        updated_lines = ensure_type(updated_node.body, cst.IndentedBlock).body
        new_body: list[cst.BaseStatement] = []
        # This transformer never adds or removes statements, so the original
        # and updated bodies line up one-to-one.
        for original_line, updated_line in zip(original_lines, updated_lines):
            if not self.matches(original_line, annotation):
                new_body.append(updated_line)
                continue
            new_line_body: list[cst.BaseSmallStatement] = [
                self.handle_annotation(stmt)
                if isinstance(stmt, cst.AnnAssign)
                else stmt
                for stmt in ensure_type(original_line, cst.SimpleStatementLine).body
            ]
            new_body.append(original_line.with_changes(body=new_line_body))

        return updated_node.with_changes(
            body=updated_node.body.with_changes(body=new_body)
        )

    def handle_annotation(self, annotation: cst.AnnAssign) -> cst.AnnAssign:
        """
        Return ``annotation`` with its default moved into ``default_factory``
        when that default is not recognised as immutable.

        Three shapes are handled:
        * no value, or an immutable value: returned unchanged;
        * a bare value: replaced by ``field(default_factory=lambda: value)``
          and an import of ``dataclasses.field`` is requested;
        * ``field(..., default=value, ...)``: the ``default=`` argument is
          swapped for ``default_factory=lambda: value`` in place, keeping the
          other arguments.
        """
        value = annotation.value
        if value is None or self.is_immutable(value):
            return annotation

        if not self.matches(value, field_call):
            AddImportsVisitor.add_needed_import(self.context, "dataclasses", "field")
            return annotation.with_changes(
                value=cst.Call(
                    func=cst.Name("field"),
                    args=[wrap_in_default_factory(value)],
                )
            )

        # we found field(...) on the RHS
        if (match := self.extract(value, default_arg)) is None:
            # no default= kwarg, nothing to do
            return annotation
        default = ensure_type(match["default_value"], cst.BaseExpression)
        if self.is_immutable(default):
            return annotation
        # rebuild the args for field(), dropping default= and adding default_factory=
        new_args: list[cst.Arg] = []
        for arg in ensure_type(value, cst.Call).args:
            if arg.keyword is None or arg.keyword.value != "default":
                new_args.append(arg)
                continue
            new_args.append(wrap_in_default_factory(default))
        return annotation.with_changes(value=value.with_changes(args=new_args))
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
# pyre-strict
6+
7+
from libcst.codemod import CodemodTest
8+
from libcst.codemod.commands.dataclass_default_factory import (
9+
DataclassDefaultFactoryCodemod,
10+
)
11+
12+
13+
class TestDataclassCommand(CodemodTest):
    """Before/after fixtures for ``DataclassDefaultFactoryCodemod``."""

    TRANSFORM = DataclassDefaultFactoryCodemod

    def test_simple_immutable(self) -> None:
        """Literal and ``False`` defaults are left untouched."""
        before = """
            from dataclasses import dataclass
            @dataclass
            class Foo:
                x: int = 1
                y: bool = False
                z: str = "foo"
        """
        self.assertCodemod(before, before)

    def test_simple_mutable(self) -> None:
        """Bare non-literal defaults are wrapped in ``field(default_factory=...)``."""
        before = """
            from dataclasses import dataclass
            @dataclass
            class Foo:
                x: list[int] = []
                y: foo = bar()
        """
        after = """
            from dataclasses import field, dataclass
            @dataclass
            class Foo:
                x: list[int] = field(default_factory = lambda: [])
                y: foo = field(default_factory = lambda: bar())
        """
        self.assertCodemod(before, after)

    def test_idempotent(self) -> None:
        """``field()`` calls without a ``default=`` argument are not modified."""
        before = """
            from dataclasses import dataclass, field
            @dataclass
            class Foo:
                x: list[int] = field(default_factory=lambda: [])
                y: list[int] = field(repr=False)
        """
        self.assertCodemod(before, before)

    def test_field_with_default(self) -> None:
        """``field(default=...)`` is rewritten only when the default is not immutable."""
        before = """
            from dataclasses import dataclass, field
            @dataclass
            class Foo:
                x: list[int] = field(default=[])
                y: bool = field(default=True)
        """
        after = """
            from dataclasses import dataclass, field
            @dataclass
            class Foo:
                x: list[int] = field(default_factory = lambda: [])
                y: bool = field(default=True)
        """
        self.assertCodemod(before, after)

0 commit comments

Comments
 (0)