The result of the JAX primitive jax.lax.conv_transpose does not match the result of torch.nn.functional.conv_transpose2d #32570
-
I fed your question to Gemini 2.5 Pro, and it says (and for the record, I agree with it):
Of course, I can help you solve this JAX issue. It's a very common point of confusion when working with transposed convolutions. The core of the problem lies in a misunderstanding of how jax.lax.conv_transpose interprets the dimensions of the kernel (your weight matrix W). Let's break it down.
In a Nutshell: The Solution
You should not transpose your weight matrix W. The error message you are seeing is correct and points to a channel mismatch created by your transpose operation. This is the correct implementation:
Did that answer your question?
-
Gentle advice, @SanjithKumar2: in your place I wouldn't leave the full path in debugging messages I share (even if it is generic), just the jax directory part.
-
@hawkinsp Thank you for taking the time to answer my question. As I mentioned, I had already tried this: it does return an output of the right shape, (1, 1, 28, 28), but it does not match the torch result. To clarify the naming conventions once more:
Convolution forward pass:
X -> (1, 1, 28, 28) (the actual input shape during the forward pass) (batch_size, in_channel, H, W)
In this case the JAX and Torch results matched perfectly.
import numpy as np
import jax.numpy as jnp
import jax.lax as L
import torch
import torch.nn.functional as F
# Torch forward
Z = F.conv2d(X, W, stride=(2,2)) + 1e-12
# Jax forward
Z_ = L.conv_general_dilated(
    lhs=X,
    rhs=jnp.transpose(W, (2,3,1,0)),
    window_strides = (2, 2),
    padding='VALID',
    dimension_numbers=('NCHW','HWIO','NCHW')
)
# Comparison
print(torch.allclose(torch.tensor(np.asarray(Z_)), Z)) # -> prints True
I also compared the forward pass with my manual implementation; both JAX and torch matched.
Now, during the backward relevance propagation (CONV TRANSPOSE):
Y -> (1, 15, 13, 13) (the output relevance score; it matches the output shape of the forward pass) (batch_size, out_channel, oH, oW)
K, K_ -> (1, 1, 28, 28) (the expected output of the conv transpose) (batch_size, in_channel (according to the forward pass), H, W)
To be clear, this is how I see the dimensions, and you can correct me here. In this case, your solution (from Gemini 2.5 Pro) did not match the torch results (the torch results matched my manually calculated results... that is why I said PyTorch was correct).
# Torch
K = F.conv_transpose2d(Y, W, stride=(2,2)) # output shape -> (1, 1, 28, 28)
# JAX
K_ = L.conv_transpose(
    lhs = Y,
    rhs = W,
    strides = (2,2),
    padding='VALID',
    dimension_numbers=('NCHW','IOHW','NCHW')
) # output shape -> (1, 1, 28, 28)
# Comparison
print(torch.allclose(K, torch.tensor(np.asarray(K_)))) # prints False
When I check against my manual implementation, the PyTorch results match, but the JAX results do not. I got traumatized by ChatGPT and Grok, and I do not have Pro versions of any LLM; that is why I came here. Hope this helps you understand the problem better...
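A likely explanation for the False, offered as a sketch rather than a definitive answer: with the default `transpose_kernel=False`, `jax.lax.conv_transpose` dilates the input by the stride and then runs an ordinary cross-correlation with the kernel exactly as given, without flipping its spatial axes. `torch.nn.functional.conv_transpose2d` is the gradient of `conv2d` with respect to its input, which amounts to correlating with a spatially flipped kernel whose input and output channels are swapped. Setting `transpose_kernel=True` asks JAX to do that flip and swap; the rhs spec then describes W as the forward-conv kernel, i.e. 'OIHW' for the PyTorch layout (15, 1, 4, 4). Flipping the H and W axes by hand while keeping 'IOHW' should be equivalent. The check below uses random stand-ins for X, W and Y (an assumption, since the real data isn't shown) and compares both variants against `jax.vjp` of the forward convolution, which is what the relevance-propagation step is after; it does not need torch.

```python
import jax
import jax.numpy as jnp
import jax.lax as L

# Random stand-ins for the thread's tensors (assumption: the real data is not shown).
kx, kw, ky = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.normal(kx, (1, 1, 28, 28))    # forward input (N, C_in, H, W)
W = jax.random.normal(kw, (15, 1, 4, 4))     # PyTorch conv2d layout (C_out, C_in, kH, kW)
Y = jax.random.normal(ky, (1, 15, 13, 13))   # relevance at the conv output

def forward(x):
    # Forward strided conv; 'OIHW' reads W in PyTorch layout, so no HWIO transpose.
    return L.conv_general_dilated(x, W, window_strides=(2, 2), padding='VALID',
                                  dimension_numbers=('NCHW', 'OIHW', 'NCHW'))

# Same forward result as transposing W to HWIO first, as in the thread.
Z_hwio = L.conv_general_dilated(X, jnp.transpose(W, (2, 3, 1, 0)), (2, 2), 'VALID',
                                dimension_numbers=('NCHW', 'HWIO', 'NCHW'))
print(jnp.allclose(forward(X), Z_hwio, atol=1e-5))  # True

# Reference: VJP of the forward conv w.r.t. its input, applied to Y.
# conv_transpose2d(Y, W, stride=2) computes this same gradient.
_, vjp_fn = jax.vjp(forward, X)
(K_ref,) = vjp_fn(Y)

# Thread's call: kernel used unflipped, so it differs from the gradient.
K_noflip = L.conv_transpose(Y, W, strides=(2, 2), padding='VALID',
                            dimension_numbers=('NCHW', 'IOHW', 'NCHW'))

# Fix 1: JAX flips H/W and swaps O/I; the spec describes W as the forward kernel.
K_fix1 = L.conv_transpose(Y, W, strides=(2, 2), padding='VALID',
                          dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
                          transpose_kernel=True)

# Fix 2: keep 'IOHW' but flip the spatial axes of W by hand.
K_fix2 = L.conv_transpose(Y, jnp.flip(W, axis=(2, 3)), strides=(2, 2),
                          padding='VALID',
                          dimension_numbers=('NCHW', 'IOHW', 'NCHW'))

print(K_ref.shape)                               # (1, 1, 28, 28)
print(jnp.allclose(K_noflip, K_ref, atol=1e-4))  # False
print(jnp.allclose(K_fix1, K_ref, atol=1e-4))    # True
print(jnp.allclose(K_fix2, K_ref, atol=1e-4))    # True
```

If this holds on your setup, the change to the JAX call above would be adding `transpose_kernel=True` and switching the rhs spec to 'OIHW'. On GPU/TPU the default convolution precision may be below full float32, so a looser tolerance or `precision=jax.lax.Precision.HIGHEST` may be needed for the comparison.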
-
@hawkinsp I want the final print statement to print True:
import numpy as np
import jax
import jax.numpy as jnp
import jax.lax as L
import torch
import torch.nn.functional as F
# Use separate PRNG keys so Y and W are independent
key_y, key_w = jax.random.split(jax.random.PRNGKey(0))
# JAX
Y = jax.random.normal(key_y, (1, 15, 13, 13))
W = jax.random.normal(key_w, (15, 1, 4, 4))
stride = (2, 2)
padding = 'VALID'
K_ = L.conv_transpose(
    lhs=Y,
    rhs=W,
    strides=stride,
    padding=padding,
    dimension_numbers=('NCHW', 'IOHW', 'NCHW')
)
# Torch (copy the JAX arrays over via NumPy)
Y = torch.tensor(np.asarray(Y))
W = torch.tensor(np.asarray(W))
K = F.conv_transpose2d(input=Y, weight=W, stride=stride)
# Compare (both sides are float32, so allow a small absolute tolerance)
K_ = torch.tensor(np.asarray(K_))
print(torch.allclose(K, K_, atol=1e-5))
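A torch-free version of the reproducer above, as a sketch: `conv_transpose2d_ref` is a hypothetical NumPy helper that spells out transposed convolution (padding=0, dilation=1, groups=1) as a scatter, where each input pixel adds a weighted copy of the unflipped kernel into the output at stride offsets. That scatter is the gradient of `conv2d` with respect to its input, which is how PyTorch describes `conv_transpose2d`. The JAX side uses `transpose_kernel=True` with an 'OIHW' rhs spec. If torch is available, replacing the helper with `F.conv_transpose2d` should give the same comparison.

```python
import numpy as np
import jax
import jax.lax as L

def conv_transpose2d_ref(y, w, stride):
    """Hypothetical NumPy reference for a transposed conv with padding=0,
    dilation=1, groups=1. y: (N, C_in, H, W); w: (C_in, C_out, kH, kW),
    the weight layout conv_transpose2d expects."""
    n, _, h, wd = y.shape
    _, c_out, kh, kw = w.shape
    sh, sw = stride
    out = np.zeros((n, c_out, (h - 1) * sh + kh, (wd - 1) * sw + kw), dtype=y.dtype)
    for i in range(h):
        for j in range(wd):
            # Each input pixel scatters a weighted copy of the unflipped kernel.
            out[:, :, i * sh:i * sh + kh, j * sw:j * sw + kw] += np.einsum(
                'nc,cokl->nokl', y[:, :, i, j], w)
    return out

key_y, key_w = jax.random.split(jax.random.PRNGKey(0))  # independent keys
Y = jax.random.normal(key_y, (1, 15, 13, 13))
W = jax.random.normal(key_w, (15, 1, 4, 4))

# JAX side: flip/swap the kernel the way the gradient of conv2d does.
K_ = L.conv_transpose(Y, W, strides=(2, 2), padding='VALID',
                      dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
                      transpose_kernel=True)

# Reference side, computed in float64.
K = conv_transpose2d_ref(np.asarray(Y, dtype=np.float64),
                         np.asarray(W, dtype=np.float64), (2, 2))

print(K_.shape, K.shape)                          # (1, 1, 28, 28) (1, 1, 28, 28)
print(np.allclose(np.asarray(K_), K, atol=1e-4))  # True
```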
-
Description
I am new to JAX... I am facing this issue: the JAX-based conv_transpose result does not match PyTorch's, and I am certain that the PyTorch result is correct.
When I call this function with 👇
input.shape = (1, 15, 13, 13)  # (batch, channels, H, W)
W.shape = (15, 1, 4, 4)  # (out_channel, in_channel, Kh, Kw)
stride = (2,2)
padding='VALID'
I should get an output of shape (1, 1, 28, 28); instead I get the following error:
If I don't transpose W, I get no error, but the result is wrong.
My understanding is that the letters of the rhs dimension numbers ('IOHW' in this case) map as follows:
I -> input_channel
O -> out_channel
H -> height
W -> width
This is why I transposed W: after the transpose, the weight matrix W has shape (1, 15, 4, 4) instead of (15, 1, 4, 4).
I also tried the following.
This also threw the same error.
What am I missing?
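On the question about the 'IOHW' letters: they only label which axis of the array you pass plays which role, and JAX checks the labelled sizes before computing anything. With the default `transpose_kernel=False`, the axis labelled 'I' must have the same size as the lhs channel axis (15 for Y). The untransposed (15, 1, 4, 4) read as 'IOHW' therefore passes the check (although, as the thread shows, its values still differ from PyTorch), while the transposed (1, 15, 4, 4) has I=1 and is rejected. With `transpose_kernel=True`, JAX swaps the 'O' and 'I' axes first, so it is the axis labelled 'O' that must match. The sketch below checks only shapes via `jax.eval_shape`; because the exception type raised for the mismatch has varied between JAX versions, it catches both TypeError and ValueError (an assumption).

```python
import jax
import jax.numpy as jnp
import jax.lax as L

# Shape-only stand-ins; jax.eval_shape needs no real data.
Y = jax.ShapeDtypeStruct((1, 15, 13, 13), jnp.float32)  # (N, C, H, W): 15 channels
W = jax.ShapeDtypeStruct((15, 1, 4, 4), jnp.float32)    # PyTorch layout
W_t = jax.ShapeDtypeStruct((1, 15, 4, 4), jnp.float32)  # W after swapping axes 0 and 1

def out_shape(rhs, rhs_spec, transpose_kernel=False):
    """Output shape of conv_transpose, or the exception name if the spec is rejected."""
    def f(y, w):
        return L.conv_transpose(y, w, strides=(2, 2), padding='VALID',
                                dimension_numbers=('NCHW', rhs_spec, 'NCHW'),
                                transpose_kernel=transpose_kernel)
    try:
        return jax.eval_shape(f, Y, rhs).shape
    except (TypeError, ValueError) as e:  # exception type has varied across versions
        return type(e).__name__

print(out_shape(W, 'IOHW'))                         # (1, 1, 28, 28): 'I' axis has size 15
print(out_shape(W_t, 'IOHW'))                       # an exception name: 'I' axis has size 1
print(out_shape(W, 'OIHW', transpose_kernel=True))  # (1, 1, 28, 28)
```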
System info (python version, jaxlib version, accelerator, etc.)