Skip to content

Commit a59baaa

Browse files
StrongerXipobin6
authored andcommitted
[dynamo] Simplify ConstantVariable.create and ConstantVariable.__init__ (pytorch#140745)
This patch removes some redundant code paths in `ConstantVariable.create` and` ConstantVariable.__init__`. Closes pytorch#110871. Pull Request resolved: pytorch#140745 Approved by: https://github.com/jansel
1 parent cc8cd91 commit a59baaa

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

torch/_dynamo/variables/constant.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,12 @@ class ConstantVariable(VariableTracker):
3838
@staticmethod
3939
def create(value, **kwargs) -> VariableTracker:
4040
source = kwargs.get("source", None)
41-
is_literal = ConstantVariable.is_literal(value)
42-
if not is_literal:
43-
for disallowed_type, reason in _type_to_assert_reason.items():
44-
assert not isinstance(value, disallowed_type), reason
4541

46-
# Routing for list and tuple literals.
47-
if is_literal and isinstance(value, (set, frozenset)):
48-
items = []
49-
for i, x in enumerate(value):
50-
items.append(ConstantVariable.create(x))
42+
# Routing for supported collection literals.
43+
if isinstance(value, (set, frozenset)):
44+
items = [ConstantVariable.create(x) for x in value]
5145
return variables.SetVariable(items, **kwargs)
52-
elif is_literal and isinstance(value, (list, tuple)):
46+
elif isinstance(value, (list, tuple)):
5347
items = []
5448
for i, x in enumerate(value):
5549
item_source = GetItemSource(source, i) if source else None
@@ -67,13 +61,10 @@ def create(value, **kwargs) -> VariableTracker:
6761

6862
def __init__(self, value, **kwargs) -> None:
6963
super().__init__(**kwargs)
70-
if not ConstantVariable.is_literal(value):
64+
if not ConstantVariable.is_base_literal(value):
7165
for disallowed_type, reason in _type_to_assert_reason.items():
7266
assert not isinstance(value, disallowed_type), reason
7367

74-
assert not isinstance(
75-
value, (list, tuple)
76-
), "ConstantVariable(list) is banned - please create a ListVariable(items)"
7768
if np is not None and isinstance(value, np.number):
7869
self.value = value.item()
7970
else:
@@ -104,14 +95,15 @@ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
10495
self.value[arg.as_python_constant()],
10596
)
10697

98+
@staticmethod
99+
def is_base_literal(obj):
100+
return type(obj) in common_constant_types
101+
107102
@staticmethod
108103
def is_literal(obj):
109-
if type(obj) in common_constant_types:
110-
return True
111-
# The structure within is_literal get routed to variables.BaseListVariable
112104
if type(obj) in (list, tuple, set, frozenset, torch.Size):
113105
return all(ConstantVariable.is_literal(x) for x in obj)
114-
return False
106+
return ConstantVariable.is_base_literal(obj)
115107

116108
def unpack_var_sequence(self, tx):
117109
try:

0 commit comments

Comments
 (0)