diff --git a/redis/commands/core.py b/redis/commands/core.py index 2613e3f8a0..73167bceb8 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -7469,12 +7469,26 @@ def zinter( ) -> ZSetRangeResponse | Awaitable[ZSetRangeResponse]: """ Return the intersect of multiple sorted sets specified by ``keys``. + With the ``aggregate`` option, it is possible to specify how the - results of the union are aggregated. This option defaults to SUM, - where the score of an element is summed across the inputs where it - exists. When this option is set to either MIN or MAX, the resulting - set will contain the minimum or maximum score of an element across - the inputs where it exists. + results of the intersection are aggregated. Available aggregation + modes: + + - ``SUM`` (default): the score of an element is summed across the + inputs where it exists. + Score = SUM(score₁×weight₁, score₂×weight₂, ...) + - ``MIN``: the resulting set will contain the minimum score of an + element across the inputs where it exists. + Score = MIN(score₁×weight₁, score₂×weight₂, ...) + - ``MAX``: the resulting set will contain the maximum score of an + element across the inputs where it exists. + Score = MAX(score₁×weight₁, score₂×weight₂, ...) + - ``COUNT``: ignores the original scores and counts weighted set + membership. Each element's score is the sum of the weights of + the input sets that contain it. + Score = SUM(weight₁, weight₂, ...) for sets containing the element. + When all weights are 1 (default), the score equals the number + of input sets containing the element. For more information, see https://redis.io/commands/zinter """ @@ -7505,11 +7519,25 @@ def zinterstore( """ Intersect multiple sorted sets specified by ``keys`` into a new sorted set, ``dest``. Scores in the destination will be aggregated - based on the ``aggregate``. This option defaults to SUM, where the - score of an element is summed across the inputs where it exists. - When this option is set to either MIN or MAX, the resulting set will - contain the minimum or maximum score of an element across the inputs - where it exists. + based on the ``aggregate``. + + Available aggregation modes: + + - ``SUM`` (default): the score of an element is summed across the + inputs where it exists. + Score = SUM(score₁×weight₁, score₂×weight₂, ...) + - ``MIN``: the resulting set will contain the minimum score of an + element across the inputs where it exists. + Score = MIN(score₁×weight₁, score₂×weight₂, ...) + - ``MAX``: the resulting set will contain the maximum score of an + element across the inputs where it exists. + Score = MAX(score₁×weight₁, score₂×weight₂, ...) + - ``COUNT``: ignores the original scores and counts weighted set + membership. Each element's score is the sum of the weights of + the input sets that contain it. + Score = SUM(weight₁, weight₂, ...) for sets containing the element. + When all weights are 1 (default), the score equals the number + of input sets containing the element. For more information, see https://redis.io/commands/zinterstore """ @@ -8490,6 +8518,24 @@ def zunion( Scores will be aggregated based on the ``aggregate``, or SUM if none is provided. + Available aggregation modes: + + - ``SUM`` (default): the score of an element is summed across the + inputs where it exists. + Score = SUM(score₁×weight₁, score₂×weight₂, ...) + - ``MIN``: the resulting set will contain the minimum score of an + element across the inputs where it exists. + Score = MIN(score₁×weight₁, score₂×weight₂, ...) + - ``MAX``: the resulting set will contain the maximum score of an + element across the inputs where it exists. + Score = MAX(score₁×weight₁, score₂×weight₂, ...) + - ``COUNT``: ignores the original scores and counts weighted set + membership. Each element's score is the sum of the weights of + the input sets that contain it. + Score = SUM(weight₁, weight₂, ...) for sets containing the element. + When all weights are 1 (default), the score equals the number + of input sets containing the element. + ``score_cast_func`` a callable used to cast the score return value For more information, see https://redis.io/commands/zunion @@ -8530,6 +8576,24 @@ def zunionstore( a new sorted set, ``dest``. Scores in the destination will be aggregated based on the ``aggregate``, or SUM if none is provided. + Available aggregation modes: + + - ``SUM`` (default): the score of an element is summed across the + inputs where it exists. + Score = SUM(score₁×weight₁, score₂×weight₂, ...) + - ``MIN``: the resulting set will contain the minimum score of an + element across the inputs where it exists. + Score = MIN(score₁×weight₁, score₂×weight₂, ...) + - ``MAX``: the resulting set will contain the maximum score of an + element across the inputs where it exists. + Score = MAX(score₁×weight₁, score₂×weight₂, ...) + - ``COUNT``: ignores the original scores and counts weighted set + membership. Each element's score is the sum of the weights of + the input sets that contain it. + Score = SUM(weight₁, weight₂, ...) for sets containing the element. + When all weights are 1 (default), the score equals the number + of input sets containing the element. + For more information, see https://redis.io/commands/zunionstore """ return self._zaggregate("ZUNIONSTORE", dest, keys, aggregate) @@ -8583,11 +8647,11 @@ def _zaggregate( pieces.append(b"WEIGHTS") pieces.extend(weights) if aggregate: - if aggregate.upper() in ["SUM", "MIN", "MAX"]: + if aggregate.upper() in ["SUM", "MIN", "MAX", "COUNT"]: pieces.append(b"AGGREGATE") pieces.append(aggregate) else: - raise DataError("aggregate can be sum, min or max.") + raise DataError("aggregate can be sum, min, max or count.") if options.get("withscores", False): pieces.append(b"WITHSCORES") options["keys"] = keys diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 0bf663800b..8cd17bc815 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -2010,6 +2010,40 @@ async def test_cluster_zinterstore_with_weight(self, r: RedisCluster) -> None: [b"a1", 23.0], ] + @skip_if_server_version_lt("8.7.0") + async def test_cluster_zinterstore_count(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + await r.zinterstore( + "{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="COUNT" + ) + == 2 + ) + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ + [b"a1", 3.0], + [b"a3", 3.0], + ] + + @skip_if_server_version_lt("8.7.0") + async def test_cluster_zinterstore_count_with_weight(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + await r.zinterstore( + "{foo}d", + {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, + aggregate="COUNT", + ) + == 2 + ) + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ + [b"a1", 6.0], + [b"a3", 6.0], + ] + @skip_if_server_version_lt("4.9.0") async def test_cluster_bzpopmax(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2}) @@ -2182,6 +2216,44 @@ async def test_cluster_zunionstore_with_weight(self, r: RedisCluster) -> None: [b"a1", 23.0], ] + @skip_if_server_version_lt("8.7.0") + async def test_cluster_zunionstore_count(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + await r.zunionstore( + "{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="COUNT" + ) + == 4 + ) + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ + [b"a4", 1.0], + [b"a2", 2.0], + [b"a1", 3.0], + [b"a3", 3.0], + ] + + @skip_if_server_version_lt("8.7.0") + async def test_cluster_zunionstore_count_with_weight(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + await r.zunionstore( + "{foo}d", + {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, + aggregate="COUNT", + ) + == 4 + ) + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ + [b"a2", 3.0], + [b"a4", 3.0], + [b"a1", 6.0], + [b"a3", 6.0], + ] + @skip_if_server_version_lt("2.8.9") async def test_cluster_pfcount(self, r: RedisCluster) -> None: members = {b"1", b"2", b"3"} diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index ccc1a59fa0..9fb70cb64b 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -2870,6 +2870,28 @@ async def test_zinterstore_with_weight(self, r: redis.Redis): response = await r.zrange("d", 0, -1, withscores=True) assert response == [[b"a3", 20.0], [b"a1", 23.0]] + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.7.0") + async def test_zinterstore_count(self, r: redis.Redis): + await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert await r.zinterstore("d", ["a", "b", "c"], aggregate="COUNT") == 2 + response = await r.zrange("d", 0, -1, withscores=True) + assert response == [[b"a1", 3.0], [b"a3", 3.0]] + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.7.0") + async def test_zinterstore_count_with_weight(self, r: redis.Redis): + await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + await r.zinterstore("d", {"a": 1, "b": 2, "c": 3}, aggregate="COUNT") == 2 + ) + response = await r.zrange("d", 0, -1, withscores=True) + assert response == [[b"a1", 6.0], [b"a3", 6.0]] + @skip_if_server_version_lt("4.9.0") async def test_zpopmax(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) @@ -3126,6 +3148,38 @@ async def test_zunionstore_with_weight(self, r: redis.Redis): response = await r.zrange("d", 0, -1, withscores=True) assert response == [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]] + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.7.0") + async def test_zunionstore_count(self, r: redis.Redis): + await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert await r.zunionstore("d", ["a", "b", "c"], aggregate="COUNT") == 4 + response = await r.zrange("d", 0, -1, withscores=True) + assert response == [ + [b"a4", 1.0], + [b"a2", 2.0], + [b"a1", 3.0], + [b"a3", 3.0], + ] + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.7.0") + async def test_zunionstore_count_with_weight(self, r: redis.Redis): + await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + await r.zunionstore("d", {"a": 1, "b": 2, "c": 3}, aggregate="COUNT") == 4 + ) + response = await r.zrange("d", 0, -1, withscores=True) + assert response == [ + [b"a2", 3.0], + [b"a4", 3.0], + [b"a1", 6.0], + [b"a3", 6.0], + ] + # HYPERLOGLOG TESTS @skip_if_server_version_lt("2.8.9") async def test_pfadd(self, r: redis.Redis): diff --git a/tests/test_cluster.py b/tests/test_cluster.py index b8537aad8a..da78aa6977 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -2130,6 +2130,38 @@ def test_cluster_zinterstore_with_weight(self, r): [b"a1", 23.0], ] + @skip_if_server_version_lt("8.7.0") + def test_cluster_zinterstore_count(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="COUNT") + == 2 + ) + assert r.zrange("{foo}d", 0, -1, withscores=True) == [ + [b"a1", 3.0], + [b"a3", 3.0], + ] + + @skip_if_server_version_lt("8.7.0") + def test_cluster_zinterstore_count_with_weight(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + r.zinterstore( + "{foo}d", + {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, + aggregate="COUNT", + ) + == 2 + ) + assert r.zrange("{foo}d", 0, -1, withscores=True) == [ + [b"a1", 6.0], + [b"a3", 6.0], + ] + @skip_if_server_version_lt("4.9.0") def test_cluster_bzpopmax(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2}) @@ -2262,6 +2294,42 @@ def test_cluster_zunionstore_with_weight(self, r): [b"a1", 23.0], ] + @skip_if_server_version_lt("8.7.0") + def test_cluster_zunionstore_count(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="COUNT") + == 4 + ) + assert r.zrange("{foo}d", 0, -1, withscores=True) == [ + [b"a4", 1.0], + [b"a2", 2.0], + [b"a1", 3.0], + [b"a3", 3.0], + ] + + @skip_if_server_version_lt("8.7.0") + def test_cluster_zunionstore_count_with_weight(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + r.zunionstore( + "{foo}d", + {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, + aggregate="COUNT", + ) + == 4 + ) + assert r.zrange("{foo}d", 0, -1, withscores=True) == [ + [b"a2", 3.0], + [b"a4", 3.0], + [b"a1", 6.0], + [b"a3", 6.0], + ] + @skip_if_server_version_lt("2.8.9") def test_cluster_pfcount(self, r): members = {b"1", b"2", b"3"} diff --git a/tests/test_commands.py b/tests/test_commands.py index 2525dff1b7..9260f4b07b 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -3803,6 +3803,25 @@ def test_zinter(self, r): [b"a1", 23], ] + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.7.0") + def test_zinter_count(self, r): + r.zadd("a", {"a1": 1, "a2": 2, "a3": 1}) + r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + # aggregate with COUNT (scores ignored, counts membership) + assert r.zinter(["a", "b", "c"], aggregate="COUNT", withscores=True) == [ + [b"a1", 3], + [b"a3", 3], + ] + # COUNT with weights + assert r.zinter( + {"a": 1, "b": 2, "c": 3}, aggregate="COUNT", withscores=True + ) == [ + [b"a1", 6], + [b"a3", 6], + ] + @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.0.0") def test_zintercard(self, r): @@ -3856,6 +3875,30 @@ def test_zinterstore_with_weight(self, r): [b"a1", 23], ] + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.7.0") + def test_zinterstore_count(self, r): + r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zinterstore("d", ["a", "b", "c"], aggregate="COUNT") == 2 + assert r.zrange("d", 0, -1, withscores=True) == [ + [b"a1", 3], + [b"a3", 3], + ] + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.7.0") + def test_zinterstore_count_with_weight(self, r): + r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zinterstore("d", {"a": 1, "b": 2, "c": 3}, aggregate="COUNT") == 2 + assert r.zrange("d", 0, -1, withscores=True) == [ + [b"a1", 6], + [b"a3", 6], + ] + @skip_if_server_version_lt("4.9.0") def test_zpopmax(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) @@ -3875,6 +3918,7 @@ def test_zrandemember(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) assert r.zrandmember("a") is not None assert len(r.zrandmember("a", 2)) == 2 + # with scores assert len(r.zrandmember("a", 2, withscores=True)) == 2 # without duplications @@ -4266,6 +4310,29 @@ def test_zunion(self, r): [b"a1", "9.0"], ] + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.7.0") + def test_zunion_count(self, r): + r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + # aggregate with COUNT (scores ignored, counts membership) + assert r.zunion(["a", "b", "c"], aggregate="COUNT", withscores=True) == [ + [b"a4", 1], + [b"a2", 2], + [b"a1", 3], + [b"a3", 3], + ] + # COUNT with weights + assert r.zunion( + {"a": 1, "b": 2, "c": 3}, aggregate="COUNT", withscores=True + ) == [ + [b"a2", 3], + [b"a4", 3], + [b"a1", 6], + [b"a3", 6], + ] + @pytest.mark.onlynoncluster def test_zunionstore_sum(self, r): r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) @@ -4318,6 +4385,34 @@ def test_zunionstore_with_weight(self, r): [b"a1", 23], ] + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.7.0") + def test_zunionstore_count(self, r): + r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zunionstore("d", ["a", "b", "c"], aggregate="COUNT") == 4 + assert r.zrange("d", 0, -1, withscores=True) == [ + [b"a4", 1], + [b"a2", 2], + [b"a1", 3], + [b"a3", 3], + ] + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.7.0") + def test_zunionstore_count_with_weight(self, r): + r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zunionstore("d", {"a": 1, "b": 2, "c": 3}, aggregate="COUNT") == 4 + assert r.zrange("d", 0, -1, withscores=True) == [ + [b"a2", 3], + [b"a4", 3], + [b"a1", 6], + [b"a3", 6], + ] + @skip_if_server_version_lt("6.1.240") def test_zmscore(self, r): with pytest.raises(exceptions.DataError):