Skip to content

Commit dd705ed

Browse files
authored
Added support for groups in conv2d (#12)
Co-authored-by: Adam Skuta <[email protected]>
1 parent 73c1845 commit dd705ed

File tree

2 files changed

+35
-32
lines changed

2 files changed

+35
-32
lines changed

tests/test_torch_nn.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -90,36 +90,40 @@ def test_torch_nn_BatchNorm1d():
9090

9191

9292
def test_torch_nn_Conv2d():
93-
for bias in [False, True]:
94-
for stride in [1, (1, 1), 2, (2, 2), (1, 3), (3, 1)]:
95-
for padding in [0, 1, 2, "valid", "same", (1, 2)]:
96-
for dilation in [1, 2, (1, 1), (2, 3)]:
97-
if padding == "same":
98-
# ValueError: padding='same' is not supported for strided convolutions
99-
stride = 1
100-
model = torch.nn.Conv2d(2, 3, (5, 5), bias=bias, stride=stride, padding=padding, dilation=dilation)
101-
102-
input_batch = 0.1 * random.normal(random.PRNGKey(123), (7, 2, 16, 16))
103-
params = {k: 0.1 * random.normal(random.PRNGKey(123), v.shape) for k, v in model.named_parameters()}
104-
105-
model.load_state_dict({k: j2t(v) for k, v in params.items()})
106-
res_torch = model(j2t(input_batch))
107-
108-
jaxified_module = t2j(model)
109-
res_jax = jaxified_module(input_batch, state_dict=params)
110-
res_jax_jit = jit(jaxified_module)(input_batch, state_dict=params)
111-
112-
# Test forward pass with and without jax.jit
113-
aac(res_jax, res_torch.numpy(force=True), atol=1e-5)
114-
aac(res_jax_jit, res_torch.numpy(force=True), atol=1e-5)
115-
116-
# Test gradients
117-
jax_grad = grad(lambda p: (jaxified_module(input_batch, state_dict=p) ** 2).sum())(params)
118-
119-
res_torch.pow(2).sum().backward()
120-
aac(jax_grad["weight"], model.weight.grad, atol=1e-4)
121-
if bias:
122-
aac(jax_grad["bias"], model.bias.grad, atol=1e-3)
93+
for out_channels in [4, 6, 8]:
94+
for bias in [False, True]:
95+
for stride in [1, (1, 1), 2, (2, 2), (1, 3), (3, 1)]:
96+
for padding in [0, 1, 2, "valid", "same", (1, 2)]:
97+
for dilation in [1, 2, (1, 1), (2, 3)]:
98+
for groups in [1, 2]:
99+
if padding == "same":
100+
# ValueError: padding='same' is not supported for strided convolutions
101+
stride = 1
102+
model = torch.nn.Conv2d(
103+
2, out_channels, (5, 5), bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups
104+
)
105+
106+
input_batch = 0.1 * random.normal(random.PRNGKey(123), (7, 2, 16, 16))
107+
params = {k: 0.1 * random.normal(random.PRNGKey(123), v.shape) for k, v in model.named_parameters()}
108+
109+
model.load_state_dict({k: j2t(v) for k, v in params.items()})
110+
res_torch = model(j2t(input_batch))
111+
112+
jaxified_module = t2j(model)
113+
res_jax = jaxified_module(input_batch, state_dict=params)
114+
res_jax_jit = jit(jaxified_module)(input_batch, state_dict=params)
115+
116+
# Test forward pass with and without jax.jit
117+
aac(res_jax, res_torch.numpy(force=True), atol=1e-5)
118+
aac(res_jax_jit, res_torch.numpy(force=True), atol=1e-5)
119+
120+
# Test gradients
121+
jax_grad = grad(lambda p: (jaxified_module(input_batch, state_dict=p) ** 2).sum())(params)
122+
123+
res_torch.pow(2).sum().backward()
124+
aac(jax_grad["weight"], model.weight.grad, atol=1e-4)
125+
if bias:
126+
aac(jax_grad["bias"], model.bias.grad, atol=1e-3)
123127

124128

125129
def test_torch_nn_ConvTranspose2d():

torch2jax/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -571,8 +571,6 @@ def conv2d(
571571
dilation=1,
572572
groups=1,
573573
):
574-
assert groups == 1, "conv2d with groups != 1 is not yet supported"
575-
576574
# jax.lax.conv_general_dilated supports different lo/hi padding, whereas PyTorch applies the same padding on both
577575
# sides. Note that we can't use the same trick as in conv_transpose2d since we also have to support "valid" and "same"
578576
# values for `padding`.
@@ -586,6 +584,7 @@ def conv2d(
586584
window_strides=stride,
587585
padding=padding,
588586
rhs_dilation=dilation,
587+
feature_group_count=groups,
589588
)
590589
if bias is not None:
591590
res += _v(bias)[jnp.newaxis, :, jnp.newaxis, jnp.newaxis]

0 commit comments

Comments
 (0)