|
14 | 14 | # ==============================================================================
|
15 | 15 | """Tests for sparse_image_warp."""
|
16 | 16 |
|
| 17 | +from collections import namedtuple |
| 18 | + |
17 | 19 | import numpy as np
|
18 | 20 | import pytest
|
19 | 21 | import tensorflow as tf
|
@@ -247,23 +249,53 @@ def test_that_backprop_runs():
|
247 | 249 | assert np.sum(np.abs(gradients)) != 0
|
248 | 250 |
|
249 | 251 |
|
| 252 | +ShapeConfig = namedtuple( |
| 253 | + "ShapeConfig", |
| 254 | + [ |
| 255 | + "image", |
| 256 | + "source_control_point_locations", |
| 257 | + "dest_control_point_locations", |
| 258 | + "input", |
| 259 | + ], |
| 260 | +) |
| 261 | + |
| 262 | + |
250 | 263 | @pytest.mark.usefixtures("maybe_run_functions_eagerly")
|
251 |
| -@pytest.mark.parametrize("shape", [(9, 12), (9, 12, 3), (1, 9, 12, 3)]) |
| 264 | +@pytest.mark.parametrize( |
| 265 | + "shape", |
| 266 | + [ |
| 267 | + ShapeConfig(None, None, None, (1, 9, 12, 3)), |
| 268 | + ShapeConfig(None, [1, 1, 2], [1, 1, 2], (9, 12)), |
| 269 | + ShapeConfig(None, [1, 1, 2], [1, 1, 2], (9, 12, 3)), |
| 270 | + ShapeConfig(None, [1, 1, 2], [1, 1, 2], (1, 9, 12, 3)), |
| 271 | + ShapeConfig([None, 9, 12, 3], [None, 1, 2], [None, 1, 2], (1, 9, 12, 3)), |
| 272 | + ShapeConfig([None, None, None, 3], [None, 1, 2], [None, 1, 2], (1, 9, 12, 3)), |
| 273 | + ShapeConfig( |
| 274 | + [None, None, None, None], [None, 1, 2], [None, 1, 2], (1, 9, 12, 3) |
| 275 | + ), |
| 276 | + ], |
| 277 | +) |
252 | 278 | @pytest.mark.parametrize("interpolation_order", [1, 2, 3])
|
253 | 279 | @pytest.mark.parametrize("num_boundary_points", [1, 2, 3])
|
254 |
| -def test_unknown_shape(shape, interpolation_order, num_boundary_points): |
| 280 | +def test_partially_or_fully_unknown_shape( |
| 281 | + shape, interpolation_order, num_boundary_points |
| 282 | +): |
255 | 283 | control_point_locations = np.asarray([3.0, 3.0]).reshape(1, 1, 2).astype(np.float32)
|
256 | 284 | control_point_displacements = (
|
257 | 285 | np.asarray([0.25, -0.5]).reshape(1, 1, 2).astype(np.float32)
|
258 | 286 | )
|
259 | 287 | fn = tf.function(sparse_image_warp).get_concrete_function(
|
260 |
| - image=tf.TensorSpec(shape=None, dtype=tf.float32), |
261 |
| - source_control_point_locations=tf.TensorSpec(shape=[1, 1, 2], dtype=tf.float32), |
262 |
| - dest_control_point_locations=tf.TensorSpec(shape=[1, 1, 2], dtype=tf.float32), |
| 288 | + image=tf.TensorSpec(shape=shape.image, dtype=tf.float32), |
| 289 | + source_control_point_locations=tf.TensorSpec( |
| 290 | + shape=shape.source_control_point_locations, dtype=tf.float32 |
| 291 | + ), |
| 292 | + dest_control_point_locations=tf.TensorSpec( |
| 293 | + shape=shape.dest_control_point_locations, dtype=tf.float32 |
| 294 | + ), |
263 | 295 | interpolation_order=interpolation_order,
|
264 | 296 | num_boundary_points=num_boundary_points,
|
265 | 297 | )
|
266 |
| - image = tf.ones(shape=shape, dtype=tf.float32) |
| 298 | + image = tf.ones(shape=shape.input, dtype=tf.float32) |
267 | 299 | expected_output = sparse_image_warp(
|
268 | 300 | image,
|
269 | 301 | control_point_locations,
|
|
0 commit comments