@@ -90,36 +90,40 @@ def test_torch_nn_BatchNorm1d():


 def test_torch_nn_Conv2d():
-    for bias in [False, True]:
-        for stride in [1, (1, 1), 2, (2, 2), (1, 3), (3, 1)]:
-            for padding in [0, 1, 2, "valid", "same", (1, 2)]:
-                for dilation in [1, 2, (1, 1), (2, 3)]:
-                    if padding == "same":
-                        # ValueError: padding='same' is not supported for strided convolutions
-                        stride = 1
-                    model = torch.nn.Conv2d(2, 3, (5, 5), bias=bias, stride=stride, padding=padding, dilation=dilation)
-
-                    input_batch = 0.1 * random.normal(random.PRNGKey(123), (7, 2, 16, 16))
-                    params = {k: 0.1 * random.normal(random.PRNGKey(123), v.shape) for k, v in model.named_parameters()}
-
-                    model.load_state_dict({k: j2t(v) for k, v in params.items()})
-                    res_torch = model(j2t(input_batch))
-
-                    jaxified_module = t2j(model)
-                    res_jax = jaxified_module(input_batch, state_dict=params)
-                    res_jax_jit = jit(jaxified_module)(input_batch, state_dict=params)
-
-                    # Test forward pass with and without jax.jit
-                    aac(res_jax, res_torch.numpy(force=True), atol=1e-5)
-                    aac(res_jax_jit, res_torch.numpy(force=True), atol=1e-5)
-
-                    # Test gradients
-                    jax_grad = grad(lambda p: (jaxified_module(input_batch, state_dict=p) ** 2).sum())(params)
-
-                    res_torch.pow(2).sum().backward()
-                    aac(jax_grad["weight"], model.weight.grad, atol=1e-4)
-                    if bias:
-                        aac(jax_grad["bias"], model.bias.grad, atol=1e-3)
+    for out_channels in [4, 6, 8]:
+        for bias in [False, True]:
+            for stride in [1, (1, 1), 2, (2, 2), (1, 3), (3, 1)]:
+                for padding in [0, 1, 2, "valid", "same", (1, 2)]:
+                    for dilation in [1, 2, (1, 1), (2, 3)]:
+                        for groups in [1, 2]:
+                            if padding == "same":
+                                # ValueError: padding='same' is not supported for strided convolutions
+                                stride = 1
+                            model = torch.nn.Conv2d(
+                                2, out_channels, (5, 5), bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups
+                            )
+
+                            input_batch = 0.1 * random.normal(random.PRNGKey(123), (7, 2, 16, 16))
+                            params = {k: 0.1 * random.normal(random.PRNGKey(123), v.shape) for k, v in model.named_parameters()}
+
+                            model.load_state_dict({k: j2t(v) for k, v in params.items()})
+                            res_torch = model(j2t(input_batch))
+
+                            jaxified_module = t2j(model)
+                            res_jax = jaxified_module(input_batch, state_dict=params)
+                            res_jax_jit = jit(jaxified_module)(input_batch, state_dict=params)
+
+                            # Test forward pass with and without jax.jit
+                            aac(res_jax, res_torch.numpy(force=True), atol=1e-5)
+                            aac(res_jax_jit, res_torch.numpy(force=True), atol=1e-5)
+
+                            # Test gradients
+                            jax_grad = grad(lambda p: (jaxified_module(input_batch, state_dict=p) ** 2).sum())(params)
+
+                            res_torch.pow(2).sum().backward()
+                            aac(jax_grad["weight"], model.weight.grad, atol=1e-4)
+                            if bias:
+                                aac(jax_grad["bias"], model.bias.grad, atol=1e-3)


 def test_torch_nn_ConvTranspose2d():
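
The new `groups` sweep is the substantive addition here: it exercises grouped convolutions, where both `in_channels` (fixed at 2) and every value in the `out_channels` sweep must be divisible by `groups`. On the JAX side, a grouped `torch.nn.Conv2d` corresponds to `jax.lax.conv_general_dilated` with `feature_group_count=groups`, which is presumably the primitive the `t2j` conversion bottoms out in. Below is a minimal standalone sketch of that correspondence, not part of the test suite; the shapes and tolerance are illustrative:

```python
import numpy as np
import torch
from jax import lax

# Grouped conv: in_channels and out_channels must both be divisible by groups.
groups, in_channels, out_channels = 2, 2, 4
conv = torch.nn.Conv2d(in_channels, out_channels, (5, 5), bias=False, groups=groups)

x = np.random.randn(7, in_channels, 16, 16).astype(np.float32)
w = conv.weight.detach().numpy()  # shape (out_channels, in_channels // groups, 5, 5)

res_torch = conv(torch.from_numpy(x)).detach().numpy()
res_jax = lax.conv_general_dilated(
    x,
    w,
    window_strides=(1, 1),
    padding="VALID",  # matches torch's default padding=0
    dimension_numbers=("NCHW", "OIHW", "NCHW"),
    feature_group_count=groups,  # splits the channel dim into independent groups
)
np.testing.assert_allclose(res_jax, res_torch, atol=1e-5)
```

Conveniently, PyTorch's OIHW weight layout already stores grouped filters as `(out_channels, in_channels // groups, kH, kW)`, which is exactly what `feature_group_count` expects, so the weight can be handed to JAX without reshaping.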