Skip to content

Commit 1c5cc83

Browse files
committed
__init__.py: fix formatting
1 parent 63afed1 commit 1c5cc83

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

torch2jax/__init__.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ def poop(self) -> jax.random.PRNGKey:
2727

2828
def mk_rng() -> jax.random.PRNGKey:
2929
assert len(_RNG_POOPER_STACK) > 0, "Attempted `mk_rng()` outside of a `RngPooperContext`"
30-
assert (
31-
_RNG_POOPER_STACK[-1] is not None
32-
), "Attempted `mk_rng()` with a `None` `RngPooperContext`. You're probably seeing this error message because you forgot to include a `rng` kwarg in your function call: `t2j(f)(..., rng=jax.random.PRNGKey(0))`. "
30+
assert _RNG_POOPER_STACK[-1] is not None, (
31+
"Attempted `mk_rng()` with a `None` `RngPooperContext`. You're probably seeing this error message because you forgot to include a `rng` kwarg in your function call: `t2j(f)(..., rng=jax.random.PRNGKey(0))`. "
32+
)
3333
return _RNG_POOPER_STACK[-1].poop()
3434

3535

@@ -82,9 +82,9 @@ def j2t_array(jax_array):
8282
class Torchish:
8383
def __init__(self, value):
8484
# See https://github.com/google/jax/issues/2115 re `isinstance(value, jnp.ndarray)`.
85-
assert (
86-
isinstance(value, jnp.ndarray) or isinstance(value, int) or isinstance(value, float)
87-
), f"Tried to create Torchish with unsupported type: {type(value)}"
85+
assert isinstance(value, jnp.ndarray) or isinstance(value, int) or isinstance(value, float), (
86+
f"Tried to create Torchish with unsupported type: {type(value)}"
87+
)
8888
self.value = value
8989

9090
# In order for PyTorch to accept an object as one of its own and allow dynamic dispatch it must either subclass
@@ -120,9 +120,9 @@ def expand(self, *sizes):
120120
newshape = [new if new != -1 else old for old, new in zip(self.shape, sizes)]
121121
for i, (old, new) in enumerate(zip(self.shape, sizes)):
122122
if old != 1:
123-
assert (
124-
newshape[i] == old
125-
), f"Attempted to expand dimension {i} from {old} to {new}. Cannot expand on non-singleton dimensions."
123+
assert newshape[i] == old, (
124+
f"Attempted to expand dimension {i} from {old} to {new}. Cannot expand on non-singleton dimensions."
125+
)
126126

127127
return Torchish(jnp.broadcast_to(self.value, newshape))
128128

0 commit comments

Comments
 (0)