2727
2828
2929@pytest .mark .parametrize ("type" , ("hadamard" , "random-hadamard" ))
30- @pytest .mark .parametrize ("randomized " , (True , False ))
30+ @pytest .mark .parametrize ("randomize " , (True , False ))
3131@pytest .mark .parametrize ("head_dim" , (None , 2 , 4 ))
3232@pytest .mark .parametrize ("input_batch_size" , (1 , 5 , 17 ))
33- def test_correctness_linear (type , randomized , head_dim , input_batch_size ):
33+ def test_correctness_linear (type , randomize , head_dim , input_batch_size ):
3434 size = (4 , 8 )
3535 module = torch .nn .Linear (* size , bias = False )
36- scheme = TransformScheme (type = type , randomized = randomized , head_dim = head_dim )
36+ scheme = TransformScheme (type = type , randomize = randomize , head_dim = head_dim )
3737 factory = TransformFactory .from_scheme (scheme , name = "" )
3838
3939 input_tfm = factory .create_transform (
@@ -58,10 +58,10 @@ def test_correctness_linear(type, randomized, head_dim, input_batch_size):
5858
5959
6060@pytest .mark .parametrize ("type" , ("hadamard" , "random-hadamard" ))
61- @pytest .mark .parametrize ("randomized " , (True , False ))
61+ @pytest .mark .parametrize ("randomize " , (True , False ))
6262@pytest .mark .parametrize ("embed_loc" , ("weight_output" , "output" ))
6363@pytest .mark .parametrize ("linear_loc" , ("input" , "weight_input" ))
64- def test_correctness_embedding (type , randomized , embed_loc , linear_loc ):
64+ def test_correctness_embedding (type , randomize , embed_loc , linear_loc ):
6565 model = torch .nn .Sequential (
6666 torch .nn .Embedding (2 , 4 ),
6767 torch .nn .Linear (4 , 8 , bias = False ),
@@ -74,7 +74,7 @@ def test_correctness_embedding(type, randomized, embed_loc, linear_loc):
7474 config_groups = {
7575 "" : TransformScheme (
7676 type = type ,
77- randomized = randomized ,
77+ randomize = randomize ,
7878 apply = [
7979 TransformArgs (targets = "Embedding" , location = embed_loc ),
8080 TransformArgs (targets = "Linear" , location = linear_loc , inverse = True ),
@@ -90,10 +90,10 @@ def test_correctness_embedding(type, randomized, embed_loc, linear_loc):
9090
9191
9292@pytest .mark .parametrize ("type" , ("hadamard" , "random-hadamard" ))
93- @pytest .mark .parametrize ("randomized " , (True , False ))
93+ @pytest .mark .parametrize ("randomize " , (True , False ))
9494@pytest .mark .parametrize ("input_batch_size" , (1 , 5 , 17 ))
9595def test_correctness_model (
96- type , randomized , input_batch_size , model_apply , offload = False
96+ type , randomize , input_batch_size , model_apply , offload = False
9797):
9898 # load model
9999 model = model_apply [0 ]
@@ -109,7 +109,7 @@ def test_correctness_model(
109109 # apply transforms
110110 config = TransformConfig (
111111 config_groups = {
112- "" : TransformScheme (type = type , randomized = randomized , apply = model_apply [1 ])
112+ "" : TransformScheme (type = type , randomize = randomize , apply = model_apply [1 ])
113113 }
114114 )
115115 apply_transform_config (model , config )
@@ -122,19 +122,17 @@ def test_correctness_model(
122122@requires_gpu
123123@requires_accelerate ()
124124@pytest .mark .parametrize ("type" , ("hadamard" , "random-hadamard" ))
125- @pytest .mark .parametrize ("randomized " , (True , False ))
125+ @pytest .mark .parametrize ("randomize " , (True , False ))
126126@pytest .mark .parametrize ("input_batch_size" , (1 , 5 , 17 ))
127- def test_correctness_model_offload (type , randomized , input_batch_size , model_apply ):
128- test_correctness_model (
129- type , randomized , input_batch_size , model_apply , offload = True
130- )
127+ def test_correctness_model_offload (type , randomize , input_batch_size , model_apply ):
128+ test_correctness_model (type , randomize , input_batch_size , model_apply , offload = True )
131129
132130
133131@pytest .mark .parametrize ("type" , ("hadamard" , "random-hadamard" ))
134- @pytest .mark .parametrize ("randomized " , (True , False ))
132+ @pytest .mark .parametrize ("randomize " , (True , False ))
135133@pytest .mark .parametrize ("head_dim" , (4 , 8 ))
136134@pytest .mark .parametrize ("input_batch_size" , (1 , 5 , 17 ))
137- def test_correctness_attention_heads (type , randomized , head_dim , input_batch_size ):
135+ def test_correctness_attention_heads (type , randomize , head_dim , input_batch_size ):
138136 hidden_size = 64
139137 num_attention_heads = 8
140138
@@ -151,7 +149,7 @@ def test_correctness_attention_heads(type, randomized, head_dim, input_batch_siz
151149 config_groups = {
152150 "" : TransformScheme (
153151 type = type ,
154- randomized = randomized ,
152+ randomize = randomize ,
155153 head_dim = head_dim ,
156154 apply = [
157155 TransformArgs (targets = "v_proj" , location = "weight_output" ),
0 commit comments