@@ -55,34 +55,11 @@ def __init__(self, data):
5555 with self .assertRaisesRegex (NotImplementedError , "arg_types" ):
5656 l .weight = torch .nn .Parameter (MyTensor (l .weight ))
5757
58- @skip_if_no_cuda ()
59- def test_default_impls (self ):
60- """Making sure some common functions has default implementations, such as
61- __tensor_unflatten__, __tensor_flatten__, _apply_fn_to_data, __repr__, to
62- """
63-
64- class MyTensor (TorchAOBaseTensor ):
65- tensor_data_names = ["qdata" ]
66- tensor_attribute_names = ["attr" , "device" ]
67-
68- def __new__ (cls , qdata , attr , device = None ):
69- shape = qdata .shape
70- if device is None :
71- device = qdata .device
72- kwargs = {"device" : device }
73- return torch .Tensor ._make_wrapper_subclass (cls , shape , ** kwargs ) # type: ignore[attr-defined]
74-
75- def __init__ (self , qdata , attr , device = None ):
76- self .qdata = qdata
77- self .attr = attr
78-
79- l = torch .nn .Linear (2 , 3 )
80- l .weight = torch .nn .Parameter (MyTensor (l .weight , "attr" ))
81- lp_tensor = l .weight
58+ def _test_default_impls_helper (self , lp_tensor , lp_tensor_for_copy ):
8259 # test __tensor_flatten__ and __tensor_unflatten__
83- tensor_data_name_dict , tensor_attributes = lp_tensor .__tensor_flatten__ ()
60+ tensor_data_names , tensor_attributes = lp_tensor .__tensor_flatten__ ()
8461 tensor_data_dict = {
85- name : getattr (lp_tensor , name ) for name in tensor_data_name_dict
62+ name : getattr (lp_tensor , name ) for name in tensor_data_names
8663 }
8764 outer_size = lp_tensor .size ()
8865 outer_stride = lp_tensor .stride ()
@@ -107,24 +84,102 @@ def __init__(self, qdata, attr, device=None):
10784 # explicitly testing aten.alias
10885 lp_tensor = torch .ops .aten .alias (lp_tensor )
10986 lp_tensor = lp_tensor .clone ()
110- # making qdata not contiguous
111- lp_tensor .qdata = lp_tensor .qdata .transpose (0 , 1 ).contiguous ()
112- lp_tensor .qdata = lp_tensor .qdata .transpose (0 , 1 )
113- self .assertFalse (lp_tensor .qdata .is_contiguous ())
114- lp_tensor = lp_tensor .contiguous ()
115- # making sure contiguous call works
116- self .assertTrue (lp_tensor .qdata .is_contiguous ())
87+ # get all tensor_data_names for both
88+ # non optional and valid optional tensors
89+ tensor_data_names = lp_tensor .tensor_data_names .copy ()
90+ if hasattr (lp_tensor , "optional_tensor_data_names" ):
91+ for tensor_data_name in lp_tensor .optional_tensor_data_names :
92+ if getattr (lp_tensor , tensor_data_name ) is not None :
93+ tensor_data_names .append (tensor_data_name )
94+
95+ # for each of the tensor data, we try to
96+ # make it non-contiguous and then use
97+ # lp_tensor.contiguous() call to make sure
98+ # contiguous() works
99+ for tensor_data_name in tensor_data_names :
100+ tensor = getattr (lp_tensor , tensor_data_name )
101+ # making qdata not contiguous
102+ tensor = tensor .transpose (0 , 1 ).contiguous ()
103+ tensor = tensor .transpose (0 , 1 )
104+ setattr (lp_tensor , tensor_data_name , tensor )
105+ self .assertFalse (getattr (lp_tensor , tensor_data_name ).is_contiguous ())
106+ lp_tensor = lp_tensor .contiguous ()
107+ # making sure contiguous call works
108+ self .assertTrue (getattr (lp_tensor , tensor_data_name ).is_contiguous ())
117109
118110 # copy_
119- another_tensor = torch .nn .Linear (2 , 3 ).weight
120- # attribute has to be the same
121- another_lp_tensor = MyTensor (another_tensor , "attr" )
122111 # initially tensor values are not the same
123- self .assertNotEqual (lp_tensor .qdata [0 ][0 ], another_lp_tensor .qdata [0 ][0 ])
124- lp_tensor .copy_ (another_lp_tensor )
112+ self .assertNotEqual (lp_tensor .qdata [0 ][0 ], lp_tensor_for_copy .qdata [0 ][0 ])
113+ lp_tensor .copy_ (lp_tensor_for_copy )
125114 self .assertEqual (lp_tensor .attr , "attr" )
126115 # after copy_, the tensor values should match
127- self .assertEqual (lp_tensor .qdata [0 ][0 ], another_lp_tensor .qdata [0 ][0 ])
116+ self .assertEqual (lp_tensor .qdata [0 ][0 ], lp_tensor_for_copy .qdata [0 ][0 ])
117+
118+ @skip_if_no_cuda ()
119+ def test_default_impls (self ):
120+ """Making sure some common functions has default implementations, such as
121+ __tensor_unflatten__, __tensor_flatten__, _apply_fn_to_data, __repr__, to
122+ """
123+
124+ class MyTensor (TorchAOBaseTensor ):
125+ tensor_data_names = ["qdata" ]
126+ tensor_attribute_names = ["attr" , "device" ]
127+
128+ def __new__ (cls , qdata , attr , device = None ):
129+ shape = qdata .shape
130+ if device is None :
131+ device = qdata .device
132+ kwargs = {"device" : device }
133+ return torch .Tensor ._make_wrapper_subclass (cls , shape , ** kwargs ) # type: ignore[attr-defined]
134+
135+ def __init__ (self , qdata , attr , device = None ):
136+ self .qdata = qdata
137+ self .attr = attr
138+
139+ l = torch .nn .Linear (2 , 3 )
140+ l .weight = torch .nn .Parameter (MyTensor (l .weight , "attr" ))
141+ lp_tensor = l .weight
142+
143+ another_tensor = torch .nn .Linear (2 , 3 ).weight
144+ # attribute has to be the same
145+ lp_tensor_for_copy = MyTensor (another_tensor , "attr" )
146+ self ._test_default_impls_helper (lp_tensor , lp_tensor_for_copy )
147+
148+ def test_default_impls_with_optional_data (self ):
149+ class MyTensorWithOptionalData (TorchAOBaseTensor ):
150+ tensor_data_names = ["qdata" ]
151+ optional_tensor_data_names = ["zero_point" ]
152+ tensor_attribute_names = ["attr" , "device" ]
153+
154+ def __new__ (cls , qdata , zero_point = None , attr = 1.0 , device = None ):
155+ shape = qdata .shape
156+ if device is None :
157+ device = qdata .device
158+ kwargs = {"device" : device }
159+ return torch .Tensor ._make_wrapper_subclass (cls , shape , ** kwargs ) # type: ignore[attr-defined]
160+
161+ def __init__ (self , qdata , zero_point = None , attr = 1.0 , device = None ):
162+ self .qdata = qdata
163+ self .zero_point = zero_point
164+ self .attr = attr
165+
166+ # test both the optional Tensor is None
167+ # and not None
168+ l = torch .nn .Linear (2 , 3 )
169+ lp_tensor = MyTensorWithOptionalData (l .weight , None , "attr" )
170+ l = torch .nn .Linear (2 , 3 )
171+ lp_tensor_for_copy = MyTensorWithOptionalData (l .weight , None , "attr" )
172+ self ._test_default_impls_helper (lp_tensor , lp_tensor_for_copy )
173+
174+ l = torch .nn .Linear (2 , 3 )
175+ lp_tensor = MyTensorWithOptionalData (
176+ l .weight , torch .zeros_like (l .weight ), "attr"
177+ )
178+ l = torch .nn .Linear (2 , 3 )
179+ lp_tensor_for_copy = MyTensorWithOptionalData (
180+ l .weight , torch .zeros_like (l .weight ), "attr"
181+ )
182+ self ._test_default_impls_helper (lp_tensor , lp_tensor_for_copy )
128183
129184
130185if __name__ == "__main__" :
0 commit comments