@@ -297,6 +297,28 @@ def test_torch_nn_MaxPool2d():
297297 aac (jax_grad , x .grad )
298298
299299
def test_torch_nn_PReLU():
    """Verify that a jaxified torch.nn.PReLU matches torch in forward and backward."""
    model = torch.nn.PReLU(3)
    key = random.PRNGKey(123)
    input_batch = random.normal(key, (1, 3, 112, 112))

    # Draw random parameter values in jax, then mirror them into the torch
    # module so both frameworks evaluate with identical weights.
    params = {name: 0.1 * random.normal(key, value.shape) for name, value in model.named_parameters()}
    model.load_state_dict({name: j2t(value) for name, value in params.items()})
    res_torch = model(j2t(input_batch))

    jaxified_module = t2j(model)
    res_jax = jaxified_module(input_batch, state_dict=params)
    res_jax_jit = jit(jaxified_module)(input_batch, state_dict=params)

    # Forward pass should agree with torch, both eagerly and under jit.
    aac(res_jax, res_torch.numpy(force=True), atol=1e-5)
    aac(res_jax_jit, res_torch.numpy(force=True), atol=1e-5)

    # Gradient of sum(out**2) w.r.t. the PReLU weight should agree as well.
    jax_grad = grad(lambda p: (jaxified_module(input_batch, state_dict=p) ** 2).sum())(params)
    res_torch.pow(2).sum().backward()
    aac(jax_grad["weight"], model.weight.grad, atol=1e-3)


321+
300322################################################################################
301323# torch.nn.functional
302324
@@ -323,6 +345,11 @@ def f(input, running_mean, running_var, weight, bias):
323345 t2j_function_test (f , [(2 , 3 , 5 , 7 ), (3 ,), (3 ,), (3 ,), (3 ,)], atol = 1e-6 )
324346
325347
def test_torch_nn_functional_prelu():
    """Test F.prelu with a single shared weight and with per-channel weights."""
    # Bug fix: `(1)` is the int 1, not a shape tuple — the weight shape for the
    # single-parameter case must be written `(1,)` to match the other specs.
    t2j_function_test(torch.nn.functional.prelu, [(6, 6), (1,)], atol=1e-6)
    t2j_function_test(torch.nn.functional.prelu, [(5, 3, 112, 122), (3,)], atol=1e-6)


326353def test_torch_nn_functional_scaled_dot_product_attention ():
327354 t2j_function_test (lambda x , y : x @ y , [(2 , 3 , 5 ), (5 , 7 )], atol = 1e-6 )
328355
0 commit comments