@@ -38,18 +38,12 @@ class ConstantVariable(VariableTracker):
38
38
@staticmethod
39
39
def create (value , ** kwargs ) -> VariableTracker :
40
40
source = kwargs .get ("source" , None )
41
- is_literal = ConstantVariable .is_literal (value )
42
- if not is_literal :
43
- for disallowed_type , reason in _type_to_assert_reason .items ():
44
- assert not isinstance (value , disallowed_type ), reason
45
41
46
- # Routing for list and tuple literals.
47
- if is_literal and isinstance (value , (set , frozenset )):
48
- items = []
49
- for i , x in enumerate (value ):
50
- items .append (ConstantVariable .create (x ))
42
+ # Routing for supported collection literals.
43
+ if isinstance (value , (set , frozenset )):
44
+ items = [ConstantVariable .create (x ) for x in value ]
51
45
return variables .SetVariable (items , ** kwargs )
52
- elif is_literal and isinstance (value , (list , tuple )):
46
+ elif isinstance (value , (list , tuple )):
53
47
items = []
54
48
for i , x in enumerate (value ):
55
49
item_source = GetItemSource (source , i ) if source else None
@@ -67,13 +61,10 @@ def create(value, **kwargs) -> VariableTracker:
67
61
68
62
def __init__ (self , value , ** kwargs ) -> None :
69
63
super ().__init__ (** kwargs )
70
- if not ConstantVariable .is_literal (value ):
64
+ if not ConstantVariable .is_base_literal (value ):
71
65
for disallowed_type , reason in _type_to_assert_reason .items ():
72
66
assert not isinstance (value , disallowed_type ), reason
73
67
74
- assert not isinstance (
75
- value , (list , tuple )
76
- ), "ConstantVariable(list) is banned - please create a ListVariable(items)"
77
68
if np is not None and isinstance (value , np .number ):
78
69
self .value = value .item ()
79
70
else :
@@ -104,14 +95,15 @@ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
104
95
self .value [arg .as_python_constant ()],
105
96
)
106
97
98
+ @staticmethod
99
+ def is_base_literal (obj ):
100
+ return type (obj ) in common_constant_types
101
+
107
102
@staticmethod
108
103
def is_literal (obj ):
109
- if type (obj ) in common_constant_types :
110
- return True
111
- # The structure within is_literal get routed to variables.BaseListVariable
112
104
if type (obj ) in (list , tuple , set , frozenset , torch .Size ):
113
105
return all (ConstantVariable .is_literal (x ) for x in obj )
114
- return False
106
+ return ConstantVariable . is_base_literal ( obj )
115
107
116
108
def unpack_var_sequence (self , tx ):
117
109
try :
0 commit comments