Not sure that they're 100% correct but here's implementations for leaky_relu, conv3d, conv3d_transpose #13
base: main
Conversation
samuela left a comment
Thanks for submitting this @Velythyl! Overall this looks good. In order to check that these implementations are indeed correct, let's add some tests. We have some existing ones that you can build off of with a little modification. See e.g. the Conv2d test here:
torch2jax/tests/test_torch_nn.py
Lines 92 to 126 in 7132ab7
def test_torch_nn_Conv2d():
    for out_channels in [4, 6, 8]:
        for bias in [False, True]:
            for stride in [1, (1, 1), 2, (2, 2), (1, 3), (3, 1)]:
                for padding in [0, 1, 2, "valid", "same", (1, 2)]:
                    for dilation in [1, 2, (1, 1), (2, 3)]:
                        for groups in [1, 2]:
                            if padding == "same":
                                # ValueError: padding='same' is not supported for strided convolutions
                                stride = 1
                            model = torch.nn.Conv2d(
                                2, out_channels, (5, 5), bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups
                            )
                            input_batch = 0.1 * random.normal(random.PRNGKey(123), (7, 2, 16, 16))
                            params = {k: 0.1 * random.normal(random.PRNGKey(123), v.shape) for k, v in model.named_parameters()}
                            model.load_state_dict({k: j2t(v) for k, v in params.items()})
                            res_torch = model(j2t(input_batch))
                            jaxified_module = t2j(model)
                            res_jax = jaxified_module(input_batch, state_dict=params)
                            res_jax_jit = jit(jaxified_module)(input_batch, state_dict=params)
                            # Test forward pass with and without jax.jit
                            aac(res_jax, res_torch.numpy(force=True), atol=1e-5)
                            aac(res_jax_jit, res_torch.numpy(force=True), atol=1e-5)
                            # Test gradients
                            jax_grad = grad(lambda p: (jaxified_module(input_batch, state_dict=p) ** 2).sum())(params)
                            res_torch.pow(2).sum().backward()
                            aac(jax_grad["weight"], model.weight.grad, atol=1e-4)
                            if bias:
                                aac(jax_grad["bias"], model.bias.grad, atol=1e-3)
@implements(torch.nn.functional.leaky_relu, Torchishify_output=False)
def leaky_relu(x, negative_slope=0.01, inplace=False):
Suggested change:
-def leaky_relu(x, negative_slope=0.01, inplace=False):
+def leaky_relu(input, negative_slope=0.01, inplace=False):
so that leaky_relu(input=foo) works, consistent with torch's argument names
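For context, one way the renamed signature could be filled in, assuming the argument arrives as (or behaves like) a plain JAX array. The decorator line mirrors the snippet quoted above; the body is an illustration, not the PR's actual implementation.

import torch
import jax.numpy as jnp
# `implements` is the registration decorator shown in the snippet above (torch2jax internals).

# Illustrative sketch, not the PR's code.
@implements(torch.nn.functional.leaky_relu, Torchishify_output=False)
def leaky_relu(input, negative_slope=0.01, inplace=False):
    # `inplace` has no meaning for immutable JAX arrays, so it is ignored here.
    return jnp.where(input >= 0, input, negative_slope * input)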
Someone should double-check that the impl is correct, but my 3D autoencoder does reconstruct the correct shapes with these implementations.
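A rough sketch of the kind of shape round-trip check being described, assuming the t2j entry point used in the test above and that torch.nn.Conv3d / torch.nn.ConvTranspose3d dispatch to the PR's conv3d / conv3d_transpose implementations. The layer sizes are arbitrary placeholders, not the author's actual autoencoder.

import torch
from jax import random
from torch2jax import t2j  # assumed import path, as used in the test file

# Toy 3D "autoencoder": Conv3d downsamples by 2x, ConvTranspose3d upsamples back.
encoder = torch.nn.Conv3d(1, 8, kernel_size=3, stride=2, padding=1)
decoder = torch.nn.ConvTranspose3d(8, 1, kernel_size=3, stride=2, padding=1, output_padding=1)
model = torch.nn.Sequential(encoder, decoder)

x = 0.1 * random.normal(random.PRNGKey(0), (1, 1, 16, 16, 16))
params = {k: 0.1 * random.normal(random.PRNGKey(0), v.shape) for k, v in model.named_parameters()}

# With matching kernel/stride/padding the spatial dims should round-trip:
# Conv3d halves each dim (16 -> 8), ConvTranspose3d restores it (8 -> 16).
out = t2j(model)(x, state_dict=params)
assert out.shape == x.shape  # (1, 1, 16, 16, 16)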