Conversation

@Velythyl Velythyl commented Jul 25, 2025

Someone should double-check that the impl is correct, but my 3D autoencoder does reconstruct the correct shapes with these implementations.
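For context, a minimal shape sanity check along these lines might look like the following sketch. The layer sizes are hypothetical, and the import path is assumed from the torch2jax README; this is not the PR's actual verification.

import torch
from jax import random
from torch2jax import t2j  # import path assumed from the torch2jax README

# Hypothetical encoder/decoder layer sizes for the sketch
conv = torch.nn.Conv3d(2, 4, kernel_size=3, stride=2, padding=1)
deconv = torch.nn.ConvTranspose3d(4, 2, kernel_size=3, stride=2, padding=1, output_padding=1)

x = 0.1 * random.normal(random.PRNGKey(0), (1, 2, 16, 16, 16))
conv_params = {k: 0.1 * random.normal(random.PRNGKey(1), v.shape) for k, v in conv.named_parameters()}
deconv_params = {k: 0.1 * random.normal(random.PRNGKey(2), v.shape) for k, v in deconv.named_parameters()}

# Down- then up-sample through the jaxified modules; the decoder should
# recover the encoder's input shape.
h = t2j(conv)(x, state_dict=conv_params)
y = t2j(deconv)(h, state_dict=deconv_params)
assert h.shape == (1, 4, 8, 8, 8)
assert y.shape == x.shape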

@Velythyl Velythyl changed the title from "Not sure that they're 100% correct but here's relu, conv3d, conv3d_transpose" to "Not sure that they're 100% correct but here's leaky_relu, conv3d, conv3d_transpose" on Jul 25, 2025
@Velythyl Velythyl changed the title from "Not sure that they're 100% correct but here's leaky_relu, conv3d, conv3d_transpose" to "Not sure that they're 100% correct but here's implementations for leaky_relu, conv3d, conv3d_transpose" on Jul 25, 2025
@samuela samuela (Owner) left a comment

Thanks for submitting this @Velythyl! Overall this looks good. In order to check that these implementations are indeed correct, let's add some tests. We have some existing ones that you can build off of with a little modification. See e.g. the Conv2d test here:

def test_torch_nn_Conv2d():
  for out_channels in [4, 6, 8]:
    for bias in [False, True]:
      for stride in [1, (1, 1), 2, (2, 2), (1, 3), (3, 1)]:
        for padding in [0, 1, 2, "valid", "same", (1, 2)]:
          for dilation in [1, 2, (1, 1), (2, 3)]:
            for groups in [1, 2]:
              if padding == "same":
                # ValueError: padding='same' is not supported for strided convolutions
                stride = 1
              model = torch.nn.Conv2d(
                2, out_channels, (5, 5), bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups
              )
              input_batch = 0.1 * random.normal(random.PRNGKey(123), (7, 2, 16, 16))
              params = {k: 0.1 * random.normal(random.PRNGKey(123), v.shape) for k, v in model.named_parameters()}
              model.load_state_dict({k: j2t(v) for k, v in params.items()})
              res_torch = model(j2t(input_batch))
              jaxified_module = t2j(model)
              res_jax = jaxified_module(input_batch, state_dict=params)
              res_jax_jit = jit(jaxified_module)(input_batch, state_dict=params)
              # Test forward pass with and without jax.jit
              aac(res_jax, res_torch.numpy(force=True), atol=1e-5)
              aac(res_jax_jit, res_torch.numpy(force=True), atol=1e-5)
              # Test gradients
              jax_grad = grad(lambda p: (jaxified_module(input_batch, state_dict=p) ** 2).sum())(params)
              res_torch.pow(2).sum().backward()
              aac(jax_grad["weight"], model.weight.grad, atol=1e-4)
              if bias:
                aac(jax_grad["bias"], model.bias.grad, atol=1e-3)



@implements(torch.nn.functional.leaky_relu, Torchishify_output=False)
def leaky_relu(x, negative_slope=0.01, inplace=False):

Suggested change:
- def leaky_relu(x, negative_slope=0.01, inplace=False):
+ def leaky_relu(input, negative_slope=0.01, inplace=False):

so that leaky_relu(input=foo) works, consistent with torch's argument names
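For illustration, a hypothetical module (not from the PR) whose forward pass passes the tensor by keyword; this call only binds correctly once the intercepted signature uses torch's parameter name input:

import torch
import torch.nn.functional as F

class Block(torch.nn.Module):  # hypothetical example module
  def forward(self, h):
    # Keyword call: binds only if the intercepted leaky_relu signature names
    # its first parameter `input`, matching torch.nn.functional.leaky_relu.
    return F.leaky_relu(input=h, negative_slope=0.2)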
