Commit 7132ab7

Merge pull request #11 from skutaada/main
Adding torch.nn.PReLU
2 parents: dd705ed + 2b29ce3

File tree: 2 files changed, +47 −0

tests/test_torch_nn.py

Lines changed: 27 additions & 0 deletions
@@ -297,6 +297,28 @@ def test_torch_nn_MaxPool2d():
     aac(jax_grad, x.grad)
 
 
+def test_torch_nn_PReLU():
+    model = torch.nn.PReLU(3)
+    input_batch = random.normal(random.PRNGKey(123), (1, 3, 112, 112))
+    params = {k: 0.1 * random.normal(random.PRNGKey(123), v.shape) for k, v in model.named_parameters()}
+    model.load_state_dict({k: j2t(v) for k, v in params.items()})
+    res_torch = model(j2t(input_batch))
+
+    jaxified_module = t2j(model)
+    res_jax = jaxified_module(input_batch, state_dict=params)
+    res_jax_jit = jit(jaxified_module)(input_batch, state_dict=params)
+
+    # Test forward pass without and with jit
+    aac(res_jax, res_torch.numpy(force=True), atol=1e-5)
+    aac(res_jax_jit, res_torch.numpy(force=True), atol=1e-5)
+
+    # Test gradients
+    jax_grad = grad(lambda p: (jaxified_module(input_batch, state_dict=p) ** 2).sum())(params)
+
+    res_torch.pow(2).sum().backward()
+    aac(jax_grad["weight"], model.weight.grad, atol=1e-3)
+
+
 ################################################################################
 # torch.nn.functional
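
For context on what this test checks: torch.nn.PReLU(3) learns one slope per channel and computes prelu(x, a) = max(0, x) + a * min(0, x). A minimal sketch of that computation, and of the weight gradient the test compares, in plain JAX (the names below are illustrative and not part of the diff):

    import jax.numpy as jnp
    from jax import grad, random

    def prelu_ref(x, a):
        # One learnable slope per channel, broadcast over an NCHW input.
        return jnp.where(x > 0, x, x * a[None, :, None, None])

    x = random.normal(random.PRNGKey(123), (1, 3, 8, 8))
    a = 0.1 * jnp.ones((3,))
    # Analogue of the test's gradient check: d/da of sum(prelu(x, a) ** 2).
    g = grad(lambda a: (prelu_ref(x, a) ** 2).sum())(a)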

@@ -323,6 +345,11 @@ def f(input, running_mean, running_var, weight, bias):
     t2j_function_test(f, [(2, 3, 5, 7), (3,), (3,), (3,), (3,)], atol=1e-6)
 
 
+def test_torch_nn_functional_prelu():
+    t2j_function_test(torch.nn.functional.prelu, [(6, 6), (1,)], atol=1e-6)
+    t2j_function_test(torch.nn.functional.prelu, [(5, 3, 112, 122), (3,)], atol=1e-6)
+
+
 def test_torch_nn_functional_scaled_dot_product_attention():
     t2j_function_test(lambda x, y: x @ y, [(2, 3, 5), (5, 7)], atol=1e-6)
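
The two cases above exercise a single shared slope (weight shape (1,)) and one slope per channel of an NCHW input (weight shape (3,)). A small standalone sketch of the torch/JAX parity being tested, mirroring what t2j_function_test automates (arrays here are made up for illustration):

    import numpy as np
    import torch
    import jax.numpy as jnp

    x = np.random.randn(5, 3, 4, 4).astype(np.float32)
    w = np.array([0.1, 0.2, 0.3], dtype=np.float32)

    res_torch = torch.nn.functional.prelu(torch.from_numpy(x), torch.from_numpy(w)).numpy()
    res_jax = jnp.where(x > 0, x, x * w[None, :, None, None])  # JAX reference
    np.testing.assert_allclose(res_jax, res_torch, rtol=1e-6, atol=1e-6)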

torch2jax/__init__.py

Lines changed: 20 additions & 0 deletions
@@ -915,6 +915,26 @@ def relu(x, inplace=False):
     return Torchish(jax.nn.relu(_v(x)))
 
 
+@implements(torch.nn.functional.prelu)
+def prelu(input: Torchish, weight: Torchish):
+    if weight.numel() != 1:
+        assert input.ndim > 0, "Zero-dim input tensors are not allowed."
+        channel_size = input.shape[1] if input.ndim >= 2 else 1
+        assert weight.numel() == channel_size, (
+            f"Mismatch of parameter numbers and input channel size. Found parameter numbers = {weight.numel()} and channel size = {channel_size}."
+        )
+    assert weight.ndim == 0 or weight.ndim == 1, (
+        f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = {weight.ndim}"
+    )
+    if input.ndim == 0:
+        weight = weight[0] if weight.ndim == 1 else weight
+    else:
+        weight = Torchish(
+            jax.lax.broadcast_in_dim(_v(weight), input.shape, () if weight.ndim == 0 else (0 if input.ndim == 1 else 1,))
+        )
+    return Torchish(jnp.where(_v(input) > 0, _v(input), _v(input) * _v(weight)))
+
+
 @implements(torch.nn.functional.scaled_dot_product_attention)
 def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
     assert attn_mask is None, "TODO: implement attn_mask"
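
The broadcast_in_dim call above is what maps the 1D per-channel weight onto the channel axis (dim 1) of the input shape. A small standalone sketch of that primitive's behavior under these arguments (shapes chosen for illustration, not taken from the diff):

    import jax
    import jax.numpy as jnp

    w = jnp.array([0.1, 0.2, 0.3])  # per-channel slopes, shape (3,)
    target_shape = (1, 3, 4, 4)     # an NCHW input shape
    # broadcast_dimensions=(1,) places w's only axis at dim 1 of the target;
    # every other output dimension is broadcast.
    w_full = jax.lax.broadcast_in_dim(w, target_shape, (1,))
    assert w_full.shape == target_shape
    assert bool((w_full[0, 1] == 0.2).all())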
