Skip to content

Commit d26e2ed

Browse files
seanpmorganWindQAQ
andauthored
Image/fix sparse image warp unknown batch size - r0.12 (#2311)
* Fix sparse_image_warp unknown batch size * More tests Co-authored-by: Tzu-Wei Sung <[email protected]>
1 parent 6f44559 commit d26e2ed

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

tensorflow_addons/image/sparse_image_warp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def _add_zero_flow_controls_at_boundary(
8585
merged_control_point_flows: augmented set of control point flows.
8686
"""
8787

88-
batch_size = tf.compat.dimension_value(control_point_locations.shape[0])
88+
batch_size = tf.shape(control_point_locations)[0]
8989

9090
boundary_point_locations = _get_boundary_locations(
9191
image_height, image_width, boundary_points_per_edge

tensorflow_addons/image/tests/sparse_image_warp_test.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# ==============================================================================
1515
"""Tests for sparse_image_warp."""
1616

17+
from collections import namedtuple
18+
1719
import numpy as np
1820
import pytest
1921
import tensorflow as tf
@@ -247,23 +249,53 @@ def test_that_backprop_runs():
247249
assert np.sum(np.abs(gradients)) != 0
248250

249251

252+
ShapeConfig = namedtuple(
253+
"ShapeConfig",
254+
[
255+
"image",
256+
"source_control_point_locations",
257+
"dest_control_point_locations",
258+
"input",
259+
],
260+
)
261+
262+
250263
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
251-
@pytest.mark.parametrize("shape", [(9, 12), (9, 12, 3), (1, 9, 12, 3)])
264+
@pytest.mark.parametrize(
265+
"shape",
266+
[
267+
ShapeConfig(None, None, None, (1, 9, 12, 3)),
268+
ShapeConfig(None, [1, 1, 2], [1, 1, 2], (9, 12)),
269+
ShapeConfig(None, [1, 1, 2], [1, 1, 2], (9, 12, 3)),
270+
ShapeConfig(None, [1, 1, 2], [1, 1, 2], (1, 9, 12, 3)),
271+
ShapeConfig([None, 9, 12, 3], [None, 1, 2], [None, 1, 2], (1, 9, 12, 3)),
272+
ShapeConfig([None, None, None, 3], [None, 1, 2], [None, 1, 2], (1, 9, 12, 3)),
273+
ShapeConfig(
274+
[None, None, None, None], [None, 1, 2], [None, 1, 2], (1, 9, 12, 3)
275+
),
276+
],
277+
)
252278
@pytest.mark.parametrize("interpolation_order", [1, 2, 3])
253279
@pytest.mark.parametrize("num_boundary_points", [1, 2, 3])
254-
def test_unknown_shape(shape, interpolation_order, num_boundary_points):
280+
def test_partially_or_fully_unknown_shape(
281+
shape, interpolation_order, num_boundary_points
282+
):
255283
control_point_locations = np.asarray([3.0, 3.0]).reshape(1, 1, 2).astype(np.float32)
256284
control_point_displacements = (
257285
np.asarray([0.25, -0.5]).reshape(1, 1, 2).astype(np.float32)
258286
)
259287
fn = tf.function(sparse_image_warp).get_concrete_function(
260-
image=tf.TensorSpec(shape=None, dtype=tf.float32),
261-
source_control_point_locations=tf.TensorSpec(shape=[1, 1, 2], dtype=tf.float32),
262-
dest_control_point_locations=tf.TensorSpec(shape=[1, 1, 2], dtype=tf.float32),
288+
image=tf.TensorSpec(shape=shape.image, dtype=tf.float32),
289+
source_control_point_locations=tf.TensorSpec(
290+
shape=shape.source_control_point_locations, dtype=tf.float32
291+
),
292+
dest_control_point_locations=tf.TensorSpec(
293+
shape=shape.dest_control_point_locations, dtype=tf.float32
294+
),
263295
interpolation_order=interpolation_order,
264296
num_boundary_points=num_boundary_points,
265297
)
266-
image = tf.ones(shape=shape, dtype=tf.float32)
298+
image = tf.ones(shape=shape.input, dtype=tf.float32)
267299
expected_output = sparse_image_warp(
268300
image,
269301
control_point_locations,

0 commit comments

Comments
 (0)