-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Add adaptive pooling (1D, 2D, 3D) support across JAX, TensorFlow, and PyTorch backends #21820
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
f99cc63
f830e93
9938ef1
323a1ab
df57227
5343b71
4cc8ac0
12edcb4
248773f
53a5dc9
2727a24
2a94421
edcf848
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1464,3 +1464,368 @@ def _pair(x): | |
| # ---- reshape -> (N, C*kH*kW, L) ---- | ||
| _, CKK, oH, oW = patches.shape | ||
| return patches.reshape(N, CKK, oH * oW) | ||
|
|
||
|
|
||
def get_static_window_sizes(input_dim, output_dim):
    """Return the two window lengths used for adaptive pooling on one axis.

    Args:
        input_dim: Static (Python int) length of the input axis.
        output_dim: Static (Python int) number of output positions.

    Returns:
        A `(small_window, big_window)` tuple, where `small_window` is
        `ceil(input_dim / output_dim)` and `big_window` is
        `small_window + 1`.

    Raises:
        ZeroDivisionError: If `output_dim` is 0.
    """
    # Integer ceil-division stays exact for any int, whereas going through
    # float division (`math.ceil(a / b)`) loses precision for large values.
    small_window = -(-input_dim // output_dim)
    return small_window, small_window + 1
|
Comment on lines
+1469
to
+1473
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some version of this is done in each backend. Can we share the code between backends? It can go in |
||
|
|
||
|
|
||
def compute_static_gather_indices(input_dim, output_size, big_window):
    """Compute gather indices for the Two-Pool Gather method.

    Output position `i` covers the input range
    `[floor(i * input_dim / output_size),
    ceil((i + 1) * input_dim / output_size))`. The returned index selects,
    from the concatenation `[small_pool, big_pool]` along the pooled axis,
    the entry that starts at that window's start and has that window's
    length.

    Args:
        input_dim: Static (Python int) length of the input axis.
        output_size: Static (Python int) number of output positions.
        big_window: The larger of the two window lengths; the smaller one
            is taken to be `big_window - 1`.

    Returns:
        An int32 array of shape `(output_size,)`.
    """
    # NOTE(review): a collaborator asked for this function to be renamed;
    # the requested name is not visible here, so the name is unchanged.
    small_window = big_window - 1
    # Number of entries produced by the stride-1 "valid" small-window pool;
    # entries of the big-window pool are stored after them.
    small_pool_len = input_dim - small_window + 1

    gather_indices = []
    for i in range(output_size):
        # Exact Python-int arithmetic: no float rounding, no int32 overflow
        # in the `i * input_dim` product.
        start = (i * input_dim) // output_size
        end = -((-(i + 1) * input_dim) // output_size)  # ceil division
        if end - start == big_window:
            gather_indices.append(start + small_pool_len)
        else:
            gather_indices.append(start)

    return jnp.asarray(gather_indices, dtype=jnp.int32)
|
|
||
|
|
||
| # ---------- 1D Adaptive Pooling ---------- | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove those |
||
def adaptive_avg_pool1d(inputs, output_size, data_format="channels_first"):
    """Adaptive average pooling over one spatial axis (Two-Pool Gather).

    The input is sum-pooled with stride 1 and "valid" padding using both
    window lengths from `get_static_window_sizes`; each sum is divided by
    its window length, the two results are concatenated along the length
    axis, and one entry per output position is gathered.

    Args:
        inputs: Rank-3 array; NCL when `data_format` is
            `"channels_first"`, otherwise treated as NLC.
        output_size: Int, or a sequence whose first item is the output
            length.
        data_format: `"channels_first"` triggers a transpose to NLC on
            entry and back to NCL on exit; any other value is left as is.

    Returns:
        The pooled array in the same layout as `inputs`.
    """
    # NOTE(review): reviewers asked for a rename of this function (target
    # name not visible here) and for more descriptive local variable names.
    sizes = (output_size,) if isinstance(output_size, int) else output_size

    channels_first = data_format == "channels_first"
    x = inputs
    if channels_first:
        x = jnp.transpose(x, (0, 2, 1))  # NCL -> NLC

    _, in_len, _ = x.shape
    out_len = sizes[0]

    small_win, big_win = get_static_window_sizes(in_len, out_len)
    gather_idx = compute_static_gather_indices(in_len, out_len, big_win)

    strides = (1, 1, 1)
    averaged = [
        lax.reduce_window(x, 0.0, lax.add, (1, win, 1), strides, "valid")
        / win
        for win in (small_win, big_win)
    ]
    stacked = jnp.concatenate(averaged, axis=1)
    result = jnp.take(stacked, gather_idx, axis=1)

    if channels_first:
        result = jnp.transpose(result, (0, 2, 1))  # NLC -> NCL
    return result
|
|
||
|
|
||
def adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
    """Adaptive max pooling over one spatial axis (Two-Pool Gather).

    The input is max-pooled with stride 1 and "valid" padding using both
    window lengths from `get_static_window_sizes`; the two results are
    concatenated along the length axis and one entry per output position
    is gathered.

    Args:
        inputs: Rank-3 array; NCL when `data_format` is
            `"channels_first"`, otherwise treated as NLC.
        output_size: Int, or a sequence whose first item is the output
            length.
        data_format: `"channels_first"` triggers a transpose to NLC on
            entry and back to NCL on exit; any other value is left as is.

    Returns:
        The pooled array in the same layout as `inputs`.
    """
    sizes = (output_size,) if isinstance(output_size, int) else output_size

    channels_first = data_format == "channels_first"
    x = inputs
    if channels_first:
        x = jnp.transpose(x, (0, 2, 1))  # NCL -> NLC

    _, in_len, _ = x.shape
    out_len = sizes[0]

    small_win, big_win = get_static_window_sizes(in_len, out_len)
    gather_idx = compute_static_gather_indices(in_len, out_len, big_win)

    strides = (1, 1, 1)
    maxima = [
        lax.reduce_window(
            x, -jnp.inf, lax.max, (1, win, 1), strides, "valid"
        )
        for win in (small_win, big_win)
    ]
    stacked = jnp.concatenate(maxima, axis=1)
    result = jnp.take(stacked, gather_idx, axis=1)

    if channels_first:
        result = jnp.transpose(result, (0, 2, 1))  # NLC -> NCL
    return result
|
|
||
|
|
||
| # ---------- 2D Adaptive Pooling ---------- | ||
def adaptive_avg_pool2d(inputs, output_size, data_format="channels_first"):
    """Adaptive average pooling over two spatial axes (Two-Pool Gather).

    The height axis is pooled first, then the width axis. For each axis the
    current array is sum-pooled with stride 1 and "valid" padding using
    both window lengths from `get_static_window_sizes`, each sum is divided
    by its window length, and one entry per output position is gathered
    from the concatenation of the two.

    Args:
        inputs: Rank-4 array; NCHW when `data_format` is
            `"channels_first"`, otherwise treated as NHWC.
        output_size: Int (used for both axes) or an `(out_h, out_w)` pair.
        data_format: `"channels_first"` triggers a transpose to NHWC on
            entry and back to NCHW on exit; any other value is left as is.

    Returns:
        The pooled array in the same layout as `inputs`.
    """
    sizes = output_size
    if isinstance(sizes, int):
        sizes = (sizes, sizes)

    channels_first = data_format == "channels_first"
    x = inputs
    if channels_first:
        x = jnp.transpose(x, (0, 2, 3, 1))  # NCHW -> NHWC

    _, in_h, in_w, _ = x.shape
    out_h, out_w = sizes

    # Work out window lengths and gather indices for every axis before any
    # pooling happens (pooling one axis does not change the others' sizes).
    plans = []
    for axis, in_dim, out_dim in ((1, in_h, out_h), (2, in_w, out_w)):
        small, big = get_static_window_sizes(in_dim, out_dim)
        gather = compute_static_gather_indices(in_dim, out_dim, big)
        plans.append((axis, small, big, gather))

    strides = (1, 1, 1, 1)
    for axis, small, big, gather in plans:
        averaged = []
        for window in (small, big):
            dims = tuple(window if i == axis else 1 for i in range(4))
            total = lax.reduce_window(
                x, 0.0, lax.add, dims, strides, "valid"
            )
            averaged.append(total / window)
        stacked = jnp.concatenate(averaged, axis=axis)
        x = jnp.take(stacked, gather, axis=axis)

    if channels_first:
        x = jnp.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
    return x
|
|
||
|
|
||
def adaptive_max_pool2d(inputs, output_size, data_format="channels_first"):
    """Adaptive max pooling over two spatial axes (Two-Pool Gather).

    The height axis is pooled first, then the width axis. For each axis the
    current array is max-pooled with stride 1 and "valid" padding using
    both window lengths from `get_static_window_sizes`, and one entry per
    output position is gathered from the concatenation of the two.

    Args:
        inputs: Rank-4 array; NCHW when `data_format` is
            `"channels_first"`, otherwise treated as NHWC.
        output_size: Int (used for both axes) or an `(out_h, out_w)` pair.
        data_format: `"channels_first"` triggers a transpose to NHWC on
            entry and back to NCHW on exit; any other value is left as is.

    Returns:
        The pooled array in the same layout as `inputs`.
    """
    sizes = output_size
    if isinstance(sizes, int):
        sizes = (sizes, sizes)

    channels_first = data_format == "channels_first"
    x = inputs
    if channels_first:
        x = jnp.transpose(x, (0, 2, 3, 1))  # NCHW -> NHWC

    _, in_h, in_w, _ = x.shape
    out_h, out_w = sizes

    # Work out window lengths and gather indices for every axis before any
    # pooling happens (pooling one axis does not change the others' sizes).
    plans = []
    for axis, in_dim, out_dim in ((1, in_h, out_h), (2, in_w, out_w)):
        small, big = get_static_window_sizes(in_dim, out_dim)
        gather = compute_static_gather_indices(in_dim, out_dim, big)
        plans.append((axis, small, big, gather))

    strides = (1, 1, 1, 1)
    for axis, small, big, gather in plans:
        maxima = []
        for window in (small, big):
            dims = tuple(window if i == axis else 1 for i in range(4))
            maxima.append(
                lax.reduce_window(
                    x, -jnp.inf, lax.max, dims, strides, "valid"
                )
            )
        stacked = jnp.concatenate(maxima, axis=axis)
        x = jnp.take(stacked, gather, axis=axis)

    if channels_first:
        x = jnp.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
    return x
|
|
||
|
|
||
| # ---------- 3D Adaptive Pooling ---------- | ||
def adaptive_avg_pool3d(inputs, output_size, data_format="channels_first"):
    """Adaptive average pooling over three spatial axes (Two-Pool Gather).

    The depth, height and width axes are pooled in that order. For each
    axis the current array is sum-pooled with stride 1 and "valid" padding
    using both window lengths from `get_static_window_sizes`, each sum is
    divided by its window length, and one entry per output position is
    gathered from the concatenation of the two.

    Args:
        inputs: Rank-5 array; NCDHW when `data_format` is
            `"channels_first"`, otherwise treated as NDHWC.
        output_size: Int (used for all three axes) or an
            `(out_d, out_h, out_w)` triple.
        data_format: `"channels_first"` triggers a transpose to NDHWC on
            entry and back to NCDHW on exit; any other value is left as is.

    Returns:
        The pooled array in the same layout as `inputs`.
    """
    sizes = output_size
    if isinstance(sizes, int):
        sizes = (sizes, sizes, sizes)

    channels_first = data_format == "channels_first"
    x = inputs
    if channels_first:
        x = jnp.transpose(x, (0, 2, 3, 4, 1))  # NCDHW -> NDHWC

    _, in_d, in_h, in_w, _ = x.shape
    out_d, out_h, out_w = sizes

    # Work out window lengths and gather indices for every axis before any
    # pooling happens (pooling one axis does not change the others' sizes).
    plans = []
    axis_specs = ((1, in_d, out_d), (2, in_h, out_h), (3, in_w, out_w))
    for axis, in_dim, out_dim in axis_specs:
        small, big = get_static_window_sizes(in_dim, out_dim)
        gather = compute_static_gather_indices(in_dim, out_dim, big)
        plans.append((axis, small, big, gather))

    strides = (1, 1, 1, 1, 1)
    for axis, small, big, gather in plans:
        averaged = []
        for window in (small, big):
            dims = tuple(window if i == axis else 1 for i in range(5))
            total = lax.reduce_window(
                x, 0.0, lax.add, dims, strides, "valid"
            )
            averaged.append(total / window)
        stacked = jnp.concatenate(averaged, axis=axis)
        x = jnp.take(stacked, gather, axis=axis)

    if channels_first:
        x = jnp.transpose(x, (0, 4, 1, 2, 3))  # NDHWC -> NCDHW
    return x
|
|
||
|
|
||
def adaptive_max_pool3d(inputs, output_size, data_format="channels_first"):
    """Adaptive max pooling over three spatial axes (Two-Pool Gather).

    The depth, height and width axes are pooled in that order. For each
    axis the current array is max-pooled with stride 1 and "valid" padding
    using both window lengths from `get_static_window_sizes`, and one entry
    per output position is gathered from the concatenation of the two.

    Args:
        inputs: Rank-5 array; NCDHW when `data_format` is
            `"channels_first"`, otherwise treated as NDHWC.
        output_size: Int (used for all three axes) or an
            `(out_d, out_h, out_w)` triple.
        data_format: `"channels_first"` triggers a transpose to NDHWC on
            entry and back to NCDHW on exit; any other value is left as is.

    Returns:
        The pooled array in the same layout as `inputs`.
    """
    sizes = output_size
    if isinstance(sizes, int):
        sizes = (sizes, sizes, sizes)

    channels_first = data_format == "channels_first"
    x = inputs
    if channels_first:
        x = jnp.transpose(x, (0, 2, 3, 4, 1))  # NCDHW -> NDHWC

    _, in_d, in_h, in_w, _ = x.shape
    out_d, out_h, out_w = sizes

    # Work out window lengths and gather indices for every axis before any
    # pooling happens (pooling one axis does not change the others' sizes).
    plans = []
    axis_specs = ((1, in_d, out_d), (2, in_h, out_h), (3, in_w, out_w))
    for axis, in_dim, out_dim in axis_specs:
        small, big = get_static_window_sizes(in_dim, out_dim)
        gather = compute_static_gather_indices(in_dim, out_dim, big)
        plans.append((axis, small, big, gather))

    strides = (1, 1, 1, 1, 1)
    for axis, small, big, gather in plans:
        maxima = []
        for window in (small, big):
            dims = tuple(window if i == axis else 1 for i in range(5))
            maxima.append(
                lax.reduce_window(
                    x, -jnp.inf, lax.max, dims, strides, "valid"
                )
            )
        stacked = jnp.concatenate(maxima, axis=axis)
        x = jnp.take(stacked, gather, axis=axis)

    if channels_first:
        x = jnp.transpose(x, (0, 4, 1, 2, 3))  # NDHWC -> NCDHW
    return x
|
|
||
|
|
||
| # ---------- Dispatcher ---------- | ||
def adaptive_avg_pool(inputs, output_size, data_format="channels_first"):
    """Route adaptive average pooling to the 1D, 2D or 3D implementation.

    The number of spatial axes is taken to be `inputs.ndim - 2`.

    Args:
        inputs: Array exposing an `ndim` attribute.
        output_size: Forwarded unchanged to the selected implementation.
        data_format: Forwarded unchanged to the selected implementation.

    Returns:
        Whatever the selected `adaptive_avg_pool{1,2,3}d` function returns.

    Raises:
        ValueError: If `inputs.ndim - 2` is not 1, 2 or 3.
    """
    spatial_rank = inputs.ndim - 2
    if spatial_rank not in (1, 2, 3):
        raise ValueError(
            "adaptive_avg_pool supports 1D, 2D, or 3D inputs only."
        )
    if spatial_rank == 1:
        return adaptive_avg_pool1d(inputs, output_size, data_format)
    if spatial_rank == 2:
        return adaptive_avg_pool2d(inputs, output_size, data_format)
    return adaptive_avg_pool3d(inputs, output_size, data_format)
|
|
||
|
|
||
def adaptive_max_pool(inputs, output_size, data_format="channels_first"):
    """Route adaptive max pooling to the 1D, 2D or 3D implementation.

    The number of spatial axes is taken to be `inputs.ndim - 2`.

    Args:
        inputs: Array exposing an `ndim` attribute.
        output_size: Forwarded unchanged to the selected implementation.
        data_format: Forwarded unchanged to the selected implementation.

    Returns:
        Whatever the selected `adaptive_max_pool{1,2,3}d` function returns.

    Raises:
        ValueError: If `inputs.ndim - 2` is not 1, 2 or 3.
    """
    spatial_rank = inputs.ndim - 2
    if spatial_rank not in (1, 2, 3):
        raise ValueError(
            "adaptive_max_pool supports 1D, 2D, or 3D inputs only."
        )
    if spatial_rank == 1:
        return adaptive_max_pool1d(inputs, output_size, data_format)
    if spatial_rank == 2:
        return adaptive_max_pool2d(inputs, output_size, data_format)
    return adaptive_max_pool3d(inputs, output_size, data_format)
|
Comment on lines
+1499
to
+1831
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The implementations for Here are a couple of suggestions:
Comment on lines
+1500
to
+1831
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The implementations for 1D, 2D, and 3D adaptive pooling for both average and max pooling contain a lot of duplicated code. Consider refactoring this by creating a generalized helper function. This function could handle the pooling logic for a single dimension and could be parameterized for average vs. max pooling. For example, you could have a helper that pools along one spatial dimension. Then, the 2D and 3D functions can be implemented by composing this helper function for each spatial dimension. This would greatly reduce code duplication and improve maintainability. |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1237,3 +1237,19 @@ def _pair(x): | |
|
|
||
| # ---- reshape -> (N, C*kH*kW, L) ---- | ||
| return patches.reshape(N, C * k[0] * k[1], -1) | ||
|
|
||
|
|
||
def adaptive_max_pool(inputs, output_size, data_format=None):
    """Placeholder for adaptive max pooling on the NumPy backend.

    All arguments are accepted and ignored.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Adaptive pooling not implemented for Numpy. "
        "Use JAX, Torch or Tensorflow backend."
    )
    raise NotImplementedError(message)
|
|
||
|
|
||
def adaptive_avg_pool(inputs, output_size, data_format=None):
    """Placeholder for adaptive average pooling on the NumPy backend.

    All arguments are accepted and ignored.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Adaptive pooling not implemented for Numpy. "
        "Use JAX, Torch or Tensorflow backend."
    )
    raise NotImplementedError(message)
|
Comment on lines
+1242
to
+1255
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is there any way we can have a NumPy implementation? If not, can we plug in the JAX implementation (like we did for convolutions)? |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Revert this file.