diff --git a/README.md b/README.md
index 03a40e4..8fc39d6 100644
--- a/README.md
+++ b/README.md
@@ -48,12 +48,51 @@ x.backward()
 assert np.allclose((a.grad, b.grad), (3., 24.))
 ```
 
+or simply wrap an existing TensorFlow function:
+
+```python
+def tf_function(a, b):
+    c = 3 * a + 4 * b * b
+
+    return c
+
+session = tf.compat.v1.Session()
+f = tfpyth.wrap_torch_from_tensorflow(
+    tf_function, ["a", "b"], session=session
+)
+# or simpler
+f = tfpyth.wrap_torch_from_tensorflow(
+    tf_function, session=session
+)  # automatically creates placeholders for "a" and "b" inside
+# or even simpler
+f = tfpyth.wrap_torch_from_tensorflow(
+    tf_function
+)  # automatically creates placeholders for "a" and "b" and the session
+
+a_ = th.tensor(1, dtype=th.float32, requires_grad=True)
+b_ = th.tensor(3, dtype=th.float32, requires_grad=True)
+x = f(a_, b_)
+
+assert x == 39.0
+
+x.backward()
+
+assert np.allclose((a_.grad, b_.grad), (3.0, 24.0))
+```
+
+* See the `tests` directory for more examples.
+
+
 ## What it's got
 
 ### `torch_from_tensorflow`
 
 Creates a PyTorch function that is differentiable by evaluating a TensorFlow output tensor given input placeholders.
 
+### `wrap_torch_from_tensorflow`
+
+Wraps a TensorFlow function into a PyTorch function and automatically creates the input placeholders.
+
 ### `eager_tensorflow_from_torch`
 
 Creates an eager Tensorflow function from a PyTorch function.
@@ -62,6 +101,22 @@ Creates an eager Tensorflow function from a PyTorch function.
 
 Creates a TensorFlow op/tensor from a PyTorch function.
 
+
+## Notes on session management
+
+* When `wrap_torch_from_tensorflow` is called without a `session` argument, a singleton session is created in the background and reused for every subsequent call to `wrap_torch_from_tensorflow`.
+* This session can be accessed via
+
+```python
+import tfpyth
+
+session = tfpyth.SingleSession().get_session()
+```
+
+
 ## Future work
 
 - [ ] support JAX
diff --git a/tests/test_adapters.py b/tests/test_adapters.py
index ab0f635..df75a91 100644
--- a/tests/test_adapters.py
+++ b/tests/test_adapters.py
@@ -55,3 +55,154 @@ def get_tf_function():
     x.backward()
 
     assert np.allclose((a_.grad, b_.grad), (3.0, 24.0))
+
+
+class Test_tensorflow_in_pytorch:
+    def test_single_output(self):
+        session = tf.Session()
+
+        def get_tf_function():
+            a = tf.placeholder(tf.float32, name="a")
+            b = tf.placeholder(tf.float32, name="b")
+            c = 3 * a + 4 * b * b
+
+            f = tfpyth.torch_from_tensorflow(session, [a, b], c).apply
+            return f
+
+        f = get_tf_function()
+        a_ = th.tensor(1, dtype=th.float32, requires_grad=True)
+        b_ = th.tensor(3, dtype=th.float32, requires_grad=True)
+        x = f(a_, b_)
+
+        assert x == 39.0
+
+        x.backward()
+
+        assert np.allclose((a_.grad, b_.grad), (3.0, 24.0))
+
+    def test_multiple_outputs(self):
+        session = tf.Session()
+
+        def get_tf_function():
+            a = tf.placeholder(tf.float32, name="a")
+            b = tf.placeholder(tf.float32, name="b")
+            c = 3 * a + 4 * b * b
+            d = 6 * a + 8 * b ** 2
+
+            f = tfpyth.torch_from_tensorflow(session, [a, b], [c, d])
+            f1, f2 = [ff.apply for ff in f]
+            return f1, f2
+
+        f1, f2 = get_tf_function()
+
+        def f(a, b):
+            return f1(a, b), f2(a, b)
+
+        a_ = th.tensor(1, dtype=th.float32, requires_grad=True)
+        b_ = th.tensor(3, dtype=th.float32, requires_grad=True)
+        x1, x2 = f(a_, b_)
+
+        assert x1 == 39.0
+        assert x2 == 78.0
+
+        x1.backward()
+        x2.backward()
+
+        # gradients from both outputs accumulate: (3 + 6, 24 + 48)
+        assert np.allclose((a_.grad, b_.grad), (9.0, 72.0))
+
+
+class Test_wrap_torch_from_tensorflow:
+    def test_image_operation(self):
+        def tensorflow_function(a, size=(128, 128)):
+            return tf.image.resize(a, size=size)
+
+        from functools import partial
+
+        session = tf.compat.v1.Session()
+        tf_func = partial(tensorflow_function, size=(128, 128))
+        f_pt = tfpyth.wrap_torch_from_tensorflow(tf_func, ["a"], [(None, 64, 64, 1)], session=session)
+        x = th.ones((1, 64, 64, 1), dtype=th.float32)
+        y = f_pt(x)
+        assert y.shape == (1, 128, 128, 1)
+
+    def test_no_gradient_operation(self):
+        def tensorflow_function(a, size=(128, 128)):
+            return tf.image.resize(a, size=size)
+
+        from functools import partial
+
+        session = tf.compat.v1.Session()
+        tf_func = partial(tensorflow_function, size=(128, 128))
+        f_pt = tfpyth.wrap_torch_from_tensorflow(tf_func, ["a"], [(None, 64, 64, 1)], session=session)
+        x = th.ones((1, 64, 64, 1), dtype=th.float32, requires_grad=False)
+        conv = th.nn.Conv2d(1, 1, 1)
+        x = conv(tfpyth.th_2D_channels_last_to_first(x))
+        x = tfpyth.th_2D_channels_first_to_last(x)
+        y = f_pt(x)
+
+        assert y.shape == (1, 128, 128, 1)
+        assert y.sum().backward() is None  # backward() returns None; this only checks that the call succeeds
+        assert conv.bias.grad is not None
+
+    def test_tensorflow_in_pytorch(self):
+        session = tf.compat.v1.Session()
+
+        def get_tf_function(a, b):
+            c = 3 * a + 4 * b * b
+
+            return c
+
+        f = tfpyth.wrap_torch_from_tensorflow(get_tf_function, ["a", "b"], None, session=session)
+        a_ = th.tensor(1, dtype=th.float32, requires_grad=True)
+        b_ = th.tensor(3, dtype=th.float32, requires_grad=True)
+        x = f(a_, b_)
+
+        assert x == 39.0
+
+        x.backward()
+
+        assert np.allclose((a_.grad, b_.grad), (3.0, 24.0))
+
+    def test_multiple_outputs(self):
+        session = tf.compat.v1.Session()
+
+        def get_tf_function(a, b):
+            c = 3 * a + 4 * b * b
+            d = 6 * a + 8 * b ** 2
+
+            return c, d
+
+        f = tfpyth.wrap_torch_from_tensorflow(get_tf_function, ["a", "b"], None, session=session)
+        a_ = th.tensor(1, dtype=th.float32, requires_grad=True)
+        b_ = th.tensor(3, dtype=th.float32, requires_grad=True)
+        x1, x2 = f(a_, b_)
+
+        assert x1 == 39.0
+        assert x2 == 78.0
+
+        x1.backward()
+        assert np.allclose((a_.grad, b_.grad), (3.0, 24.0))
+        x2.backward()  # partial derivatives are additive
+        assert np.allclose((a_.grad, b_.grad), (9.0, 72.0))
+
+    def test_autodetect_varnames(self):
+        def get_tf_function(a, b):
+            c = 3 * a + 4 * b * b
+
+            return c
+
+        f = tfpyth.wrap_torch_from_tensorflow(get_tf_function)
+        a_ = th.tensor(1, dtype=th.float32, requires_grad=True)
+        b_ = th.tensor(3, dtype=th.float32, requires_grad=True)
+        x = f(a_, b_)
+
+        assert x == 39.0
+
+        x.backward()
+
+        assert np.allclose((a_.grad, b_.grad), (3.0, 24.0))
diff --git a/tfpyth/__init__.py b/tfpyth/__init__.py
index 4fcee7f..720c21f 100644
--- a/tfpyth/__init__.py
+++ b/tfpyth/__init__.py
@@ -1,7 +1,19 @@
 import tensorflow as tf
 import torch as th
+import functools
+
+
+class SingleSession:
+    """Process-wide singleton holding a shared tf.compat.v1.Session.
+
+    See https://python-3-patterns-idioms-test.readthedocs.io/en/latest/Singleton.html
+    """
+
+    instance = None
+
+    def __init__(self):
+        if not SingleSession.instance:
+            SingleSession.instance = tf.compat.v1.Session()
+
+    def get_session(self):
+        return SingleSession.instance
 
 
 class TensorFlowFunction(th.autograd.Function):
     """
     Wrapper class for Tensorflow input/output nodes (incl gradient) in PyTorch.
@@ -13,7 +25,7 @@ class TensorFlowFunction(th.autograd.Function):
     gradient_outputs = None
 
 
-def torch_from_tensorflow(tf_session, tf_inputs, tf_output, tf_dtype=tf.float32):
+def torch_from_tensorflow(tf_session, tf_inputs, tf_outputs, tf_dtype=tf.float32):
     """
     Create a PyTorch TensorFlowFunction with forward and backward methods which executes evaluates the passed
     TensorFlow tensors.
@@ -31,42 +43,122 @@ def torch_from_tensorflow(tf_session, tf_inputs, tf_output, tf_dtype=tf.float32)
     :return: TensorflowFunction which can be applied to PyTorch tensors.
     """
     # create gradient placeholders
-    tf_gradient_placeholder = tf.placeholder(dtype=tf_dtype, name=f"gradient")
-    tf_gradient_outputs = tf.gradients(
-        ys=tf_output, xs=tf_inputs, grad_ys=[tf_gradient_placeholder], unconnected_gradients="zero"
-    )
-
-    class _TensorFlowFunction(TensorFlowFunction):
-        inputs = tf_inputs
-        output = tf_output
-        gradient_placeholder = tf_gradient_placeholder
-        gradient_outputs = tf_gradient_outputs
-
-        @staticmethod
-        def forward(ctx, *args):
-            assert len(args) == len(tf_inputs)
-
-            feed_dict = {tf_input: th_input.detach().numpy() for tf_input, th_input in zip(tf_inputs, args)}
-            output = tf_session.run(tf_output, feed_dict)
-            ctx.save_for_backward(*args)
-
-            th_output = th.as_tensor(output)
-            return th_output
-
-        # See https://www.janfreyberg.com/blog/2019-04-01-testing-pytorch-functions/ for why "no cover"
-        @staticmethod
-        def backward(ctx, grad_output):  # pragma: no cover
-            th_inputs = ctx.saved_tensors
-
-            feed_dict = {}
-            feed_dict.update({tf_input: th_input.detach().numpy() for tf_input, th_input in zip(tf_inputs, th_inputs)})
-            feed_dict.update({tf_gradient_placeholder: grad_output.detach().numpy()})
-
-            tf_gradients = tf_session.run(tf_gradient_outputs, feed_dict)
-            return tuple(th.as_tensor(tf_gradient) for tf_gradient in tf_gradients)
-
-    return _TensorFlowFunction()
+    def _torch_from_tensorflow(tf_session, tf_inputs, tf_output, tf_dtype=tf.float32):
+        tf_gradient_placeholder = tf.placeholder(dtype=tf_dtype, name="gradient")
+        tf_gradient_outputs = tf.gradients(
+            ys=tf_output, xs=tf_inputs, grad_ys=[tf_gradient_placeholder], unconnected_gradients="zero"
+        )
+
+        class _TensorFlowFunction(TensorFlowFunction):
+            inputs = tf_inputs
+            output = tf_output
+            gradient_placeholder = tf_gradient_placeholder
+            gradient_outputs = tf_gradient_outputs
+
+            @staticmethod
+            def forward(ctx, *args):
+                assert len(args) == len(tf_inputs)
+
+                # CUDA tensors must be moved to the CPU before they can be fed to TensorFlow
+                feed_dict = {}
+                for tf_input, th_input in zip(tf_inputs, args):
+                    if th_input.is_cuda:
+                        feed_dict[tf_input] = th_input.cpu().detach().numpy()
+                    else:
+                        feed_dict[tf_input] = th_input.detach().numpy()
+
+                # TODO: write test for cuda tensors
+                output = tf_session.run(tf_output, feed_dict)
+
+                ctx.save_for_backward(*args)
+
+                th_output = th.as_tensor(output)
+                return th_output
+
+            # See https://www.janfreyberg.com/blog/2019-04-01-testing-pytorch-functions/ for why "no cover"
+            @staticmethod
+            def backward(ctx, grad_output):  # pragma: no cover
+                th_inputs = ctx.saved_tensors
+
+                feed_dict = {}
+                feed_dict.update(
+                    {tf_input: th_input.detach().numpy() for tf_input, th_input in zip(tf_inputs, th_inputs)}
+                )
+                feed_dict.update({tf_gradient_placeholder: grad_output.detach().numpy()})
+
+                tf_gradients = tf_session.run(tf_gradient_outputs, feed_dict)
+                return tuple(th.as_tensor(tf_gradient) for tf_gradient in tf_gradients)
+
+        return _TensorFlowFunction()
+
+    if isinstance(tf_outputs, list):
+        output_functions = []
+        for tf_output in tf_outputs:
+            output_functions.append(_torch_from_tensorflow(tf_session, tf_inputs, tf_output, tf_dtype))
+        return output_functions
+    else:
+        return _torch_from_tensorflow(tf_session, tf_inputs, tf_outputs, tf_dtype)
+
+
+def wrap_torch_from_tensorflow(func, tensor_inputs=None, input_shapes=None, input_dtypes=None, session=None):
+    """Wrap `func` using `torch_from_tensorflow` and automatically create the input placeholders.
+
+    By default, placeholders are assumed to be `tf.float32`.
+
+    :param func: Callable.
+        TensorFlow function to evaluate.
+    :param tensor_inputs: List[str].
+        List of argument names of `func` that represent a tensor input.
+        If not provided, all arguments of `func` are interpreted as TensorFlow placeholders.
+    :param input_shapes: List[Tuple[int]].
+        Shapes of the input tensors, if known. Some operations require these, such as `tf.image.resize`.
+        The values are passed to `tf.compat.v1.placeholder`, so unknown dimensions can be given as `None`,
+        e.g. `(None, 64, 64, 1)`.
+    :param input_dtypes: List[tf.DType].
+        Data types to associate with the inputs. By default, all inputs are treated as `tf.float32`.
+    :param session: tf.compat.v1.Session.
+        A session. If None, the shared `SingleSession` singleton is used.
+    """
+    if session is None:
+        session = SingleSession().get_session()
+    if tensor_inputs is None:
+        # resolve the underlying function of a functools.partial before inspecting its signature
+        if isinstance(func, functools.partial):
+            func = func.func
+        tensor_inputs = func.__code__.co_varnames[: func.__code__.co_argcount]
+
+    if input_shapes is not None:
+        if len(tensor_inputs) != len(input_shapes):
+            raise ValueError("Number of tensor inputs does not match number of input shapes")
+        else:
+            if input_dtypes is not None:
+                if len(input_dtypes) != len(input_shapes):
+                    raise ValueError("Number of tensor input dtypes does not match number of input shapes")
+                else:
+                    placeholders = {
+                        arg_name: tf.compat.v1.placeholder(shape=shape, dtype=dtype, name=arg_name)
+                        for arg_name, shape, dtype in zip(tensor_inputs, input_shapes, input_dtypes)
+                    }
+            else:
+                placeholders = {
+                    arg_name: tf.compat.v1.placeholder(tf.float32, shape=shape, name=arg_name)
+                    for arg_name, shape in zip(tensor_inputs, input_shapes)
+                }
+    else:
+        placeholders = {arg_name: tf.compat.v1.placeholder(tf.float32, name=arg_name) for arg_name in tensor_inputs}
+    outputs = func(**placeholders)
+
+    if isinstance(outputs, tuple):
+        fs = [
+            torch_from_tensorflow(session, [placeholders[t] for t in tensor_inputs], output).apply for output in outputs
+        ]
+
+        def f(*args):
+            return [ff(*args) for ff in fs]
+
+    else:
+        output = outputs
+        f = torch_from_tensorflow(session, [placeholders[t] for t in tensor_inputs], output).apply
+    return f
 
 
 def eager_tensorflow_from_torch(func):
@@ -106,3 +198,27 @@ def tensorflow_from_torch(func, inp, Tout, name=None):
     eager_compute = eager_tensorflow_from_torch(func)
 
     return tf.py_function(eager_compute, inp, Tout, name=name)
+
+
+def tf_NCHW_to_NHWC(x):
+    return tf.transpose(x, (0, 2, 3, 1))
+
+
+def tf_NHWC_to_NCHW(x):
+    return tf.transpose(x, (0, 3, 1, 2))
+
+
+tf_2D_channels_first_to_last = tf_NCHW_to_NHWC
+tf_2D_channels_last_to_first = tf_NHWC_to_NCHW
+
+
+def th_NCHW_to_NHWC(x):
+    return x.permute((0, 2, 3, 1))
+
+
+def th_NHWC_to_NCHW(x):
+    return x.permute((0, 3, 1, 2))
+
+
+th_2D_channels_last_to_first = th_NHWC_to_NCHW
+th_2D_channels_first_to_last = th_NCHW_to_NHWC
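For orientation, here is a minimal end-to-end sketch of how the pieces introduced above fit together, modeled on `test_no_gradient_operation`: a TensorFlow image op is wrapped into a differentiable PyTorch function and placed between the channel-order helpers. The function and variable names are illustrative; TF 1.x graph mode is assumed, as in the tests.

```python
import tensorflow as tf
import torch as th

import tfpyth


def resize_images(a):
    # TensorFlow graph function operating on NHWC images
    return tf.image.resize(a, size=(128, 128))


session = tf.compat.v1.Session()
f_pt = tfpyth.wrap_torch_from_tensorflow(resize_images, ["a"], [(None, 64, 64, 1)], session=session)

conv = th.nn.Conv2d(1, 1, kernel_size=1)
x = th.ones((1, 64, 64, 1), dtype=th.float32)      # NHWC input
h = conv(tfpyth.th_2D_channels_last_to_first(x))   # PyTorch layers expect NCHW
h = tfpyth.th_2D_channels_first_to_last(h)         # back to NHWC for TensorFlow
y = f_pt(h)                                        # shape (1, 128, 128, 1)

y.sum().backward()                                 # gradients flow back through the TF op
print(conv.bias.grad)
```

The conversion helpers exist because PyTorch layers expect channels-first (NCHW) tensors while `tf.image` ops work on channels-last (NHWC) tensors.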