diff --git a/redis/commands/vectorset/__init__.py b/redis/commands/vectorset/__init__.py index d78580a73b..590f238370 100644 --- a/redis/commands/vectorset/__init__.py +++ b/redis/commands/vectorset/__init__.py @@ -24,12 +24,12 @@ def __init__(self, client, **kwargs): # Set the module commands' callbacks self._MODULE_CALLBACKS = { VEMB_CMD: parse_vemb_result, + VSIM_CMD: parse_vsim_result, VGETATTR_CMD: lambda r: r and json.loads(r) or None, } self._RESP2_MODULE_CALLBACKS = { VINFO_CMD: lambda r: r and pairs_to_dict(r) or None, - VSIM_CMD: parse_vsim_result, VLINKS_CMD: parse_vlinks_result, } self._RESP3_MODULE_CALLBACKS = {} diff --git a/redis/commands/vectorset/commands.py b/redis/commands/vectorset/commands.py index 6123b3e7f5..4ef769d71f 100644 --- a/redis/commands/vectorset/commands.py +++ b/redis/commands/vectorset/commands.py @@ -1,6 +1,6 @@ import json from enum import Enum -from typing import Awaitable, Dict, List, Optional, Union +from typing import Any, Awaitable, Dict, List, Optional, Union from redis.client import NEVER_DECODE from redis.commands.helpers import get_protocol_version @@ -19,6 +19,15 @@ VGETATTR_CMD = "VGETATTR" VRANDMEMBER_CMD = "VRANDMEMBER" +# Return type for vsim command +VSimResult = Optional[ + List[ + Union[ + List[EncodableT], Dict[EncodableT, Number], Dict[EncodableT, Dict[str, Any]] + ] + ] +] + class QuantizationOptions(Enum): """Quantization options for the VADD command.""" @@ -33,6 +42,7 @@ class CallbacksOptions(Enum): RAW = "RAW" WITHSCORES = "WITHSCORES" + WITHATTRIBS = "WITHATTRIBS" ALLOW_DECODING = "ALLOW_DECODING" RESP3 = "RESP3" @@ -123,6 +133,7 @@ def vsim( key: KeyT, input: Union[List[float], bytes, str], with_scores: Optional[bool] = False, + with_attribs: Optional[bool] = False, count: Optional[int] = None, ef: Optional[Number] = None, filter: Optional[str] = None, @@ -130,15 +141,14 @@ def vsim( truth: Optional[bool] = False, no_thread: Optional[bool] = False, epsilon: Optional[Number] = None, - ) -> Union[ - Awaitable[Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]]], - Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]], - ]: + ) -> Union[Awaitable[VSimResult], VSimResult]: """ Compare a vector or element ``input`` with the other vectors in a vector set ``key``. - ``with_scores`` sets if the results should be returned with the - similarity scores of the elements in the result. + ``with_scores`` sets if similarity scores should be returned for each element in the result. + + ``with_attribs`` ``with_attribs`` sets if the results should be returned with the + attributes of the elements in the result, or None when no attributes are present. ``count`` sets the number of results to return. @@ -173,9 +183,17 @@ def vsim( else: pieces.extend(["ELE", input]) - if with_scores: - pieces.append("WITHSCORES") - options[CallbacksOptions.WITHSCORES.value] = True + if with_scores or with_attribs: + if get_protocol_version(self.client) in ["3", 3]: + options[CallbacksOptions.RESP3.value] = True + + if with_scores: + pieces.append("WITHSCORES") + options[CallbacksOptions.WITHSCORES.value] = True + + if with_attribs: + pieces.append("WITHATTRIBS") + options[CallbacksOptions.WITHATTRIBS.value] = True if count: pieces.extend(["COUNT", count]) diff --git a/redis/commands/vectorset/utils.py b/redis/commands/vectorset/utils.py index ed6d194ae0..61ad1eba51 100644 --- a/redis/commands/vectorset/utils.py +++ b/redis/commands/vectorset/utils.py @@ -1,3 +1,5 @@ +import json + from redis._parsers.helpers import pairs_to_dict from redis.commands.vectorset.commands import CallbacksOptions @@ -75,19 +77,53 @@ def parse_vsim_result(response, **options): structures depending on input options. Parsing VSIM result into: - List[List[str]] - - List[Dict[str, Number]] + - List[Dict[str, Number]] - when with_scores is used (without attributes) + - List[Dict[str, Mapping[str, Any]]] - when with_attribs is used (without scores) + - List[Dict[str, Union[Number, Mapping[str, Any]]]] - when with_scores and with_attribs are used + """ if response is None: return response - if options.get(CallbacksOptions.WITHSCORES.value): + withscores = bool(options.get(CallbacksOptions.WITHSCORES.value)) + withattribs = bool(options.get(CallbacksOptions.WITHATTRIBS.value)) + + # Exactly one of withscores or withattribs is True + if (withscores and not withattribs) or (not withscores and withattribs): # Redis will return a list of list of pairs. # This list have to be transformed to dict result_dict = {} - for key, value in pairs_to_dict(response).items(): - value = float(value) + if options.get(CallbacksOptions.RESP3.value): + resp_dict = response + else: + resp_dict = pairs_to_dict(response) + for key, value in resp_dict.items(): + if withscores: + value = float(value) + else: + value = json.loads(value) if value else None + result_dict[key] = value return result_dict + elif withscores and withattribs: + it = iter(response) + result_dict = {} + if options.get(CallbacksOptions.RESP3.value): + for elem, data in response.items(): + if data[1] is not None: + attribs_dict = json.loads(data[1]) + else: + attribs_dict = None + result_dict[elem] = {"score": data[0], "attributes": attribs_dict} + else: + for elem, score, attribs in zip(it, it, it): + if attribs is not None: + attribs_dict = json.loads(attribs) + else: + attribs_dict = None + + result_dict[elem] = {"score": float(score), "attributes": attribs_dict} + return result_dict else: # return the list of elements for each level # list of lists diff --git a/tests/test_asyncio/test_vsets.py b/tests/test_asyncio/test_vsets.py index 8294d8aff4..e2b2bc1d4f 100644 --- a/tests/test_asyncio/test_vsets.py +++ b/tests/test_asyncio/test_vsets.py @@ -262,6 +262,80 @@ async def test_vsim_with_scores(d_client): assert 0 <= vsim["elem1"] <= 1 +@skip_if_server_version_lt("8.2.0") +async def test_vsim_with_attribs_attribs_set(d_client): + elements_count = 5 + vector_dim = 10 + attrs_dict = {"key1": "value1", "key2": "value2"} + for i in range(elements_count): + float_array = [random.uniform(0, 5) for x in range(vector_dim)] + await d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=64, + attributes=attrs_dict if i % 2 == 0 else None, + ) + + vsim = await d_client.vset().vsim("myset", input="elem1", with_attribs=True) + assert len(vsim) == 5 + assert isinstance(vsim, dict) + assert vsim["elem1"] is None + assert vsim["elem2"] == attrs_dict + + +@skip_if_server_version_lt("8.2.0") +async def test_vsim_with_scores_and_attribs_attribs_set(d_client): + elements_count = 5 + vector_dim = 10 + attrs_dict = {"key1": "value1", "key2": "value2"} + for i in range(elements_count): + float_array = [random.uniform(0, 5) for x in range(vector_dim)] + await d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=64, + attributes=attrs_dict if i % 2 == 0 else None, + ) + + vsim = await d_client.vset().vsim( + "myset", input="elem1", with_scores=True, with_attribs=True + ) + assert len(vsim) == 5 + assert isinstance(vsim, dict) + assert isinstance(vsim["elem1"], dict) + assert "score" in vsim["elem1"] + assert "attributes" in vsim["elem1"] + assert isinstance(vsim["elem1"]["score"], float) + assert vsim["elem1"]["attributes"] is None + + assert isinstance(vsim["elem2"], dict) + assert "score" in vsim["elem2"] + assert "attributes" in vsim["elem2"] + assert isinstance(vsim["elem2"]["score"], float) + assert vsim["elem2"]["attributes"] == attrs_dict + + +@skip_if_server_version_lt("8.2.0") +async def test_vsim_with_attribs_attribs_not_set(d_client): + elements_count = 20 + vector_dim = 50 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + await d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=64, + ) + + vsim = await d_client.vset().vsim("myset", input="elem1", with_attribs=True) + assert len(vsim) == 10 + assert isinstance(vsim, dict) + assert vsim["elem1"] is None + + @skip_if_server_version_lt("7.9.0") async def test_vsim_with_different_vector_input_types(d_client): elements_count = 10 @@ -785,13 +859,51 @@ async def test_vrandmember(d_client): assert members_list == [] +@skip_if_server_version_lt("8.2.0") +async def test_8_2_new_vset_features_without_decoding_responces(client): + # test vadd + elements = ["elem1", "elem2", "elem3"] + attrs_dict = {"key1": "value1", "key2": "value2"} + for elem in elements: + float_array = [random.uniform(0.5, 10) for x in range(0, 8)] + resp = await client.vset().vadd( + "myset", float_array, element=elem, attributes=attrs_dict + ) + assert resp == 1 + + # test vsim with attributes + vsim_with_attribs = await client.vset().vsim( + "myset", input="elem1", with_attribs=True + ) + assert len(vsim_with_attribs) == 3 + assert isinstance(vsim_with_attribs, dict) + assert isinstance(vsim_with_attribs[b"elem1"], dict) + assert vsim_with_attribs[b"elem1"] == attrs_dict + + # test vsim with score and attributes + vsim_with_scores_and_attribs = await client.vset().vsim( + "myset", input="elem1", with_scores=True, with_attribs=True + ) + assert len(vsim_with_scores_and_attribs) == 3 + assert isinstance(vsim_with_scores_and_attribs, dict) + assert isinstance(vsim_with_scores_and_attribs[b"elem1"], dict) + assert "score" in vsim_with_scores_and_attribs[b"elem1"] + assert "attributes" in vsim_with_scores_and_attribs[b"elem1"] + assert isinstance(vsim_with_scores_and_attribs[b"elem1"]["score"], float) + assert isinstance(vsim_with_scores_and_attribs[b"elem1"]["attributes"], dict) + assert vsim_with_scores_and_attribs[b"elem1"]["attributes"] == attrs_dict + + @skip_if_server_version_lt("7.9.0") async def test_vset_commands_without_decoding_responces(client): # test vadd elements = ["elem1", "elem2", "elem3"] + attrs_dict = {"key1": "value1", "key2": "value2"} for elem in elements: float_array = [random.uniform(0.5, 10) for x in range(0, 8)] - resp = await client.vset().vadd("myset", float_array, element=elem) + resp = await client.vset().vadd( + "myset", float_array, element=elem, attributes=attrs_dict + ) assert resp == 1 # test vemb diff --git a/tests/test_vsets.py b/tests/test_vsets.py index e212b1b286..cba7115bf1 100644 --- a/tests/test_vsets.py +++ b/tests/test_vsets.py @@ -264,6 +264,80 @@ def test_vsim_with_scores(d_client): assert 0 <= vsim["elem1"] <= 1 +@skip_if_server_version_lt("8.2.0") +def test_vsim_with_attribs_attribs_set(d_client): + elements_count = 5 + vector_dim = 10 + attrs_dict = {"key1": "value1", "key2": "value2"} + for i in range(elements_count): + float_array = [random.uniform(0, 5) for x in range(vector_dim)] + d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=64, + attributes=attrs_dict if i % 2 == 0 else None, + ) + + vsim = d_client.vset().vsim("myset", input="elem1", with_attribs=True) + assert len(vsim) == 5 + assert isinstance(vsim, dict) + assert vsim["elem1"] is None + assert vsim["elem2"] == attrs_dict + + +@skip_if_server_version_lt("8.2.0") +def test_vsim_with_scores_and_attribs_attribs_set(d_client): + elements_count = 5 + vector_dim = 10 + attrs_dict = {"key1": "value1", "key2": "value2"} + for i in range(elements_count): + float_array = [random.uniform(0, 5) for x in range(vector_dim)] + d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=64, + attributes=attrs_dict if i % 2 == 0 else None, + ) + + vsim = d_client.vset().vsim( + "myset", input="elem1", with_scores=True, with_attribs=True + ) + assert len(vsim) == 5 + assert isinstance(vsim, dict) + assert isinstance(vsim["elem1"], dict) + assert "score" in vsim["elem1"] + assert "attributes" in vsim["elem1"] + assert isinstance(vsim["elem1"]["score"], float) + assert vsim["elem1"]["attributes"] is None + + assert isinstance(vsim["elem2"], dict) + assert "score" in vsim["elem2"] + assert "attributes" in vsim["elem2"] + assert isinstance(vsim["elem2"]["score"], float) + assert vsim["elem2"]["attributes"] == attrs_dict + + +@skip_if_server_version_lt("8.2.0") +def test_vsim_with_attribs_attribs_not_set(d_client): + elements_count = 20 + vector_dim = 50 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=64, + ) + + vsim = d_client.vset().vsim("myset", input="elem1", with_attribs=True) + assert len(vsim) == 10 + assert isinstance(vsim, dict) + assert vsim["elem1"] is None + + @skip_if_server_version_lt("7.9.0") def test_vsim_with_different_vector_input_types(d_client): elements_count = 10 @@ -785,13 +859,49 @@ def test_vrandmember(d_client): assert members_list == [] +@skip_if_server_version_lt("8.2.0") +def test_8_2_new_vset_features_without_decoding_responces(client): + # test vadd + elements = ["elem1", "elem2", "elem3"] + attrs_dict = {"key1": "value1", "key2": "value2"} + for elem in elements: + float_array = [random.uniform(0.5, 10) for x in range(0, 8)] + resp = client.vset().vadd( + "myset", float_array, element=elem, attributes=attrs_dict + ) + assert resp == 1 + + # test vsim with attributes + vsim_with_attribs = client.vset().vsim("myset", input="elem1", with_attribs=True) + assert len(vsim_with_attribs) == 3 + assert isinstance(vsim_with_attribs, dict) + assert isinstance(vsim_with_attribs[b"elem1"], dict) + assert vsim_with_attribs[b"elem1"] == attrs_dict + + # test vsim with score and attributes + vsim_with_scores_and_attribs = client.vset().vsim( + "myset", input="elem1", with_scores=True, with_attribs=True + ) + assert len(vsim_with_scores_and_attribs) == 3 + assert isinstance(vsim_with_scores_and_attribs, dict) + assert isinstance(vsim_with_scores_and_attribs[b"elem1"], dict) + assert "score" in vsim_with_scores_and_attribs[b"elem1"] + assert "attributes" in vsim_with_scores_and_attribs[b"elem1"] + assert isinstance(vsim_with_scores_and_attribs[b"elem1"]["score"], float) + assert isinstance(vsim_with_scores_and_attribs[b"elem1"]["attributes"], dict) + assert vsim_with_scores_and_attribs[b"elem1"]["attributes"] == attrs_dict + + @skip_if_server_version_lt("7.9.0") def test_vset_commands_without_decoding_responces(client): # test vadd elements = ["elem1", "elem2", "elem3"] + attrs_dict = {"key1": "value1", "key2": "value2"} for elem in elements: float_array = [random.uniform(0.5, 10) for x in range(0, 8)] - resp = client.vset().vadd("myset", float_array, element=elem) + resp = client.vset().vadd( + "myset", float_array, element=elem, attributes=attrs_dict + ) assert resp == 1 # test vemb