Skip to content

Commit e717bb6

Browse files
committed
Add transaction method to client
1 parent 45c0099 commit e717bb6

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

mockredis/client.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from mockredis.clock import SystemClock
1313
from mockredis.lock import MockRedisLock
14-
from mockredis.exceptions import RedisError, ResponseError
14+
from mockredis.exceptions import RedisError, ResponseError, WatchError
1515
from mockredis.pipeline import MockRedisPipeline
1616
from mockredis.script import Script
1717
from mockredis.sortedset import SortedSet
@@ -75,6 +75,30 @@ def pipeline(self, transaction=True, shard_hint=None):
7575
"""Emulate a redis-python pipeline."""
7676
return MockRedisPipeline(self, transaction, shard_hint)
7777

78+
def transaction(self, func, *watches, **kwargs):
79+
"""
80+
Convenience method for executing the callable `func` as a transaction
81+
while watching all keys specified in `watches`. The 'func' callable
82+
should expect a single argument which is a Pipeline object.
83+
84+
Copied directly from redis-py.
85+
"""
86+
shard_hint = kwargs.pop('shard_hint', None)
87+
value_from_callable = kwargs.pop('value_from_callable', False)
88+
watch_delay = kwargs.pop('watch_delay', None)
89+
with self.pipeline(True, shard_hint) as pipe:
90+
while 1:
91+
try:
92+
if watches:
93+
pipe.watch(*watches)
94+
func_value = func(pipe)
95+
exec_value = pipe.execute()
96+
return func_value if value_from_callable else exec_value
97+
except WatchError:
98+
if watch_delay is not None and watch_delay > 0:
99+
time.sleep(watch_delay)
100+
continue
101+
78102
def watch(self, *argv, **kwargs):
79103
"""
80104
Mock does not support command buffering so watch

mockredis/tests/test_pipeline.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,30 @@ def test_pipeline_args(self):
3333
with self.redis.pipeline(transaction=False, shard_hint=None):
3434
pass
3535

36+
def test_transaction(self):
37+
self.redis["a"] = 1
38+
self.redis["b"] = 2
39+
has_run = []
40+
41+
def my_transaction(pipe):
42+
a_value = pipe.get("a")
43+
assert a_value in (b"1", b"2")
44+
b_value = pipe.get("b")
45+
assert b_value == b"2"
46+
47+
# silly run-once code... incr's "a" so WatchError should be raised
48+
# forcing this all to run again. this should incr "a" once to "2"
49+
if not has_run:
50+
self.redis.incr("a")
51+
has_run.append(True)
52+
53+
pipe.multi()
54+
pipe.set("c", int(a_value) + int(b_value))
55+
56+
result = self.redis.transaction(my_transaction, "a", "b")
57+
eq_([True], result)
58+
eq_(b"4", self.redis["c"])
59+
3660
def test_set_and_get(self):
3761
"""
3862
Pipeline execution returns the pipeline, not the intermediate value.

0 commit comments

Comments
 (0)