diff --git a/redis/_parsers/commands.py b/redis/_parsers/commands.py index b5109252ae..a7571ac195 100644 --- a/redis/_parsers/commands.py +++ b/redis/_parsers/commands.py @@ -1,11 +1,43 @@ +from dataclasses import dataclass +from enum import Enum from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union -from redis.exceptions import RedisError, ResponseError +from redis.exceptions import RedisError, ResponseError, IncorrectPolicyType from redis.utils import str_if_bytes if TYPE_CHECKING: from redis.asyncio.cluster import ClusterNode +class RequestPolicy(Enum): + ALL_NODES = 'all_nodes' + ALL_SHARDS = 'all_shards' + MULTI_SHARD = 'multi_shard' + SPECIAL = 'special' + DEFAULT_KEYLESS = 'default_keyless' + DEFAULT_KEYED = 'default_keyed' + +class ResponsePolicy(Enum): + ONE_SUCCEEDED = 'one_succeeded' + ALL_SUCCEEDED = 'all_succeeded' + AGG_LOGICAL_AND = 'agg_logical_and' + AGG_LOGICAL_OR = 'agg_logical_or' + AGG_MIN = 'agg_min' + AGG_MAX = 'agg_max' + AGG_SUM = 'agg_sum' + SPECIAL = 'special' + DEFAULT_KEYLESS = 'default_keyless' + DEFAULT_KEYED = 'default_keyed' + +class CommandPolicies: + def __init__( + self, + request_policy: RequestPolicy = RequestPolicy.DEFAULT_KEYLESS, + response_policy: ResponsePolicy = ResponsePolicy.DEFAULT_KEYLESS + ): + self.request_policy = request_policy + self.response_policy = response_policy + +PolicyRecords = dict[str, dict[str, CommandPolicies]] class AbstractCommandsParser: def _get_pubsub_keys(self, *args): @@ -64,7 +96,8 @@ class CommandsParser(AbstractCommandsParser): def __init__(self, redis_connection): self.commands = {} - self.initialize(redis_connection) + self.redis_connection = redis_connection + self.initialize(self.redis_connection) def initialize(self, r): commands = r.command() @@ -169,6 +202,173 @@ def _get_moveable_keys(self, redis_conn, *args): raise e return keys + def _is_keyless_command(self, command_name: str, subcommand_name: Optional[str]=None) -> bool: + """ + Determines whether a given command or subcommand is considered "keyless". + + A keyless command does not operate on specific keys, which is determined based + on the first key position in the command or subcommand details. If the command + or subcommand's first key position is zero or negative, it is treated as keyless. + + Parameters: + command_name: str + The name of the command to check. + subcommand_name: Optional[str], default=None + The name of the subcommand to check, if applicable. If not provided, + the check is performed only on the command. + + Returns: + bool + True if the specified command or subcommand is considered keyless, + False otherwise. + + Raises: + ValueError + If the specified subcommand is not found within the command or the + specified command does not exist in the available commands. + """ + if subcommand_name: + for subcommand in self.commands.get(command_name)['subcommands']: + if str_if_bytes(subcommand[0]) == subcommand_name: + parsed_subcmd = self.parse_subcommand(subcommand) + return parsed_subcmd['first_key_pos'] <= 0 + raise ValueError(f"Subcommand {subcommand_name} not found in command {command_name}") + else: + command_details = self.commands.get(command_name, None) + if command_details is not None: + return command_details['first_key_pos'] <= 0 + + raise ValueError(f"Command {command_name} not found in commands") + + def get_command_policies(self) -> PolicyRecords: + """ + Retrieve and process the command policies for all commands and subcommands. + + This method traverses through commands and subcommands, extracting policy details + from associated data structures and constructing a dictionary of commands with their + associated policies. It supports nested data structures and handles both main commands + and their subcommands. + + Returns: + PolicyRecords: A collection of commands and subcommands associated with their + respective policies. + + Raises: + IncorrectPolicyType: If an invalid policy type is encountered during policy extraction. + """ + command_with_policies = {} + + def extract_policies(data, module_name, command_name): + """ + Recursively extract policies from nested data structures. + + Args: + data: The data structure to search (can be list, dict, str, bytes, etc.) + command_name: The command name to associate with found policies + """ + if isinstance(data, (str, bytes)): + # Decode bytes to string if needed + policy = str_if_bytes(data.decode()) + + # Check if this is a policy string + if policy.startswith('request_policy') or policy.startswith('response_policy'): + if policy.startswith('request_policy'): + policy_type = policy.split(':')[1] + + try: + command_with_policies[module_name][command_name].request_policy = RequestPolicy(policy_type) + except ValueError: + raise IncorrectPolicyType(f"Incorrect request policy type: {policy_type}") + + if policy.startswith('response_policy'): + policy_type = policy.split(':')[1] + + try: + command_with_policies[module_name][command_name].response_policy = ResponsePolicy(policy_type) + except ValueError: + raise IncorrectPolicyType(f"Incorrect response policy type: {policy_type}") + + elif isinstance(data, list): + # For lists, recursively process each element + for item in data: + extract_policies(item, module_name, command_name) + + elif isinstance(data, dict): + # For dictionaries, recursively process each value + for value in data.values(): + extract_policies(value, module_name, command_name) + + for command, details in self.commands.items(): + # Check whether the command has keys + is_keyless = self._is_keyless_command(command) + + if is_keyless: + default_request_policy = RequestPolicy.DEFAULT_KEYLESS + default_response_policy = ResponsePolicy.DEFAULT_KEYLESS + else: + default_request_policy = RequestPolicy.DEFAULT_KEYED + default_response_policy = ResponsePolicy.DEFAULT_KEYED + + # Check if it's a core or module command + split_name = command.split('.') + + if len(split_name) > 1: + module_name = split_name[0] + command_name = split_name[1] + else: + module_name = 'core' + command_name = split_name[0] + + # Create a CommandPolicies object with default policies on the new command. + if command_with_policies.get(module_name, None) is None: + command_with_policies[module_name] = {command_name: CommandPolicies( + request_policy=default_request_policy, + response_policy=default_response_policy + )} + else: + command_with_policies[module_name][command_name] = CommandPolicies( + request_policy=default_request_policy, + response_policy=default_response_policy + ) + + tips = details.get('tips') + subcommands = details.get('subcommands') + + # Process tips for the main command + if tips: + extract_policies(tips, module_name, command_name) + + # Process subcommands + if subcommands: + for subcommand_details in subcommands: + # Get the subcommand name (first element) + subcmd_name = subcommand_details[0] + if isinstance(subcmd_name, bytes): + subcmd_name = subcmd_name.decode() + + # Check whether the subcommand has keys + is_keyless = self._is_keyless_command(command, subcmd_name) + + if is_keyless: + default_request_policy = RequestPolicy.DEFAULT_KEYLESS + default_response_policy = ResponsePolicy.DEFAULT_KEYLESS + else: + default_request_policy = RequestPolicy.DEFAULT_KEYED + default_response_policy = ResponsePolicy.DEFAULT_KEYED + + subcmd_name = subcmd_name.replace('|', ' ') + + # Create a CommandPolicies object with default policies on the new command. + command_with_policies[module_name][subcmd_name] = CommandPolicies( + request_policy=default_request_policy, + response_policy=default_response_policy + ) + + # Recursively extract policies from the rest of the subcommand details + for subcommand_detail in subcommand_details[1:]: + extract_policies(subcommand_detail, module_name, subcmd_name) + + return command_with_policies class AsyncCommandsParser(AbstractCommandsParser): """ diff --git a/redis/commands/policies.py b/redis/commands/policies.py new file mode 100644 index 0000000000..a2f7f45924 --- /dev/null +++ b/redis/commands/policies.py @@ -0,0 +1,130 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from redis._parsers.commands import CommandPolicies, PolicyRecords, RequestPolicy, ResponsePolicy, CommandsParser + +STATIC_POLICIES: PolicyRecords = { + 'ft': { + 'explaincli': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'suglen': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED), + 'profile': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'dropindex': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'aliasupdate': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'alter': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'aggregate': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'syndump': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'create': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'explain': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'sugget': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED), + 'dictdel': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'aliasadd': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'dictadd': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'synupdate': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'drop': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'info': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'sugadd': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED), + 'dictdump': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'cursor': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'search': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'tagvals': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'aliasdel': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'sugdel': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED), + 'spellcheck': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + } +} + +class PolicyResolver(ABC): + + @abstractmethod + def resolve(self, command_name: str) -> CommandPolicies: + """ + Resolves the command name and determines the associated command policies. + + Args: + command_name: The name of the command to resolve. + + Returns: + CommandPolicies: The policies associated with the specified command. + """ + pass + + @abstractmethod + def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver": + """ + Factory method to instantiate a policy resolver with a fallback resolver. + + Args: + fallback: Fallback resolver + + Returns: + PolicyResolver: Returns a new policy resolver with the specified fallback resolver. + """ + pass + +class BasePolicyResolver(PolicyResolver): + """ + Base class for policy resolvers. + """ + def __init__(self, policies: PolicyRecords, fallback: Optional[PolicyResolver] = None) -> None: + self._policies = policies + self._fallback = fallback + + def resolve(self, command_name: str) -> CommandPolicies: + parts = command_name.split(".") + + if len(parts) > 2: + raise ValueError(f"Wrong command or module name: {command_name}") + + module, command = parts if len(parts) == 2 else ("core", parts[0]) + + if self._policies.get(module, None) is None: + if self._fallback is not None: + return self._fallback.resolve(command_name) + else: + raise ValueError(f"Module {module} not found") + + if self._policies.get(module).get(command, None) is None: + if self._fallback is not None: + return self._fallback.resolve(command_name) + else: + raise ValueError(f"Command {command} not found in module {module}") + + return self._policies.get(module).get(command) + + @abstractmethod + def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver": + pass + + +class DynamicPolicyResolver(BasePolicyResolver): + """ + Resolves policy dynamically based on the COMMAND output. + """ + def __init__(self, commands_parser: CommandsParser, fallback: Optional[PolicyResolver] = None) -> None: + """ + Parameters: + commands_parser (CommandsParser): COMMAND output parser. + fallback (Optional[PolicyResolver]): An optional resolver to be used when the + primary policies cannot handle a specific request. + """ + self._commands_parser = commands_parser + super().__init__(commands_parser.get_command_policies(), fallback) + + def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver": + return DynamicPolicyResolver(self._commands_parser, fallback) + + +class StaticPolicyResolver(BasePolicyResolver): + """ + Resolves policy from a static list of policy records. + """ + def __init__(self, fallback: Optional[PolicyResolver] = None) -> None: + """ + Parameters: + fallback (Optional[PolicyResolver]): An optional fallback policy resolver + used for resolving policies if static policies are inadequate. + """ + super().__init__(STATIC_POLICIES, fallback) + + def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver": + return StaticPolicyResolver(fallback) \ No newline at end of file diff --git a/redis/exceptions.py b/redis/exceptions.py index 643444986b..458ba5843f 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -245,3 +245,9 @@ class InvalidPipelineStack(RedisClusterException): """ pass + +class IncorrectPolicyType(Exception): + """ + Raised when a policy type isn't matching to any known policy types. + """ + pass \ No newline at end of file diff --git a/tests/test_command_parser.py b/tests/test_command_parser.py index e3b44a147f..6be43e5823 100644 --- a/tests/test_command_parser.py +++ b/tests/test_command_parser.py @@ -1,5 +1,8 @@ +from pprint import pprint + import pytest from redis._parsers import CommandsParser +from redis._parsers.commands import RequestPolicy, ResponsePolicy from .conftest import ( assert_resp_response, @@ -106,3 +109,63 @@ def test_get_pubsub_keys(self, r): assert commands_parser.get_keys(r, *args2) == ["foo1", "foo2", "foo3"] assert commands_parser.get_keys(r, *args3) == ["*"] assert commands_parser.get_keys(r, *args4) == ["foo1", "foo2", "foo3"] + + @skip_if_server_version_lt("7.0.0") + @pytest.mark.onlycluster + def test_get_command_policies(self, r): + commands_parser = CommandsParser(r) + expected_command_policies = { + 'core': { + 'keys': ['keys', RequestPolicy.ALL_SHARDS, ResponsePolicy.DEFAULT_KEYLESS], + 'acl setuser': ['acl setuser', RequestPolicy.ALL_NODES, ResponsePolicy.ALL_SUCCEEDED], + 'exists': ['exists', RequestPolicy.MULTI_SHARD, ResponsePolicy.AGG_SUM], + 'config resetstat': ['config resetstat', RequestPolicy.ALL_NODES, ResponsePolicy.ALL_SUCCEEDED], + 'slowlog len': ['slowlog len', RequestPolicy.ALL_NODES, ResponsePolicy.AGG_SUM], + 'scan': ['scan', RequestPolicy.SPECIAL, ResponsePolicy.SPECIAL], + 'latency history': ['latency history', RequestPolicy.ALL_NODES, ResponsePolicy.SPECIAL], + 'memory doctor': ['memory doctor', RequestPolicy.ALL_SHARDS, ResponsePolicy.SPECIAL], + 'randomkey': ['randomkey', RequestPolicy.ALL_SHARDS, ResponsePolicy.SPECIAL], + 'mget': ['mget', RequestPolicy.MULTI_SHARD, ResponsePolicy.DEFAULT_KEYED], + 'function restore': ['function restore', RequestPolicy.ALL_SHARDS, ResponsePolicy.ALL_SUCCEEDED], + }, + 'json': { + 'debug': ['debug', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + 'get': ['get', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + }, + 'ft': { + 'search': ['search', RequestPolicy.DEFAULT_KEYLESS, ResponsePolicy.DEFAULT_KEYLESS], + 'create': ['create', RequestPolicy.DEFAULT_KEYLESS, ResponsePolicy.DEFAULT_KEYLESS], + }, + 'bf': { + 'add': ['add', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + 'madd': ['madd', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + }, + 'cf': { + 'add': ['add', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + 'mexists': ['mexists', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + }, + 'tdigest': { + 'add': ['add', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + 'min': ['min', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + }, + 'ts': { + 'create': ['create', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + 'info': ['info', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + }, + 'topk': { + 'list': ['list', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + 'query': ['query', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + } + } + + actual_policies = commands_parser.get_command_policies() + assert len(actual_policies) > 0 + + for module_name, commands in expected_command_policies.items(): + for command, command_policies in commands.items(): + assert command in actual_policies[module_name] + assert command_policies == [ + command, + actual_policies[module_name][command].request_policy, + actual_policies[module_name][command].response_policy + ] \ No newline at end of file diff --git a/tests/test_command_policies.py b/tests/test_command_policies.py new file mode 100644 index 0000000000..c0d057f0b0 --- /dev/null +++ b/tests/test_command_policies.py @@ -0,0 +1,57 @@ +from unittest.mock import Mock + +import pytest + +from redis._parsers import CommandsParser +from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy +from redis.commands.policies import DynamicPolicyResolver, StaticPolicyResolver + + +@pytest.mark.onlycluster +class TestBasePolicyResolver: + def test_resolve(self): + mock_command_parser = Mock(spec=CommandsParser) + zcount_policy = CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED) + rpoplpush_policy = CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED) + + mock_command_parser.get_command_policies.return_value = { + 'core': { + 'zcount': zcount_policy, + 'rpoplpush': rpoplpush_policy, + } + } + + dynamic_resolver = DynamicPolicyResolver(mock_command_parser) + assert dynamic_resolver.resolve('zcount') == zcount_policy + assert dynamic_resolver.resolve('rpoplpush') == rpoplpush_policy + + with pytest.raises(ValueError, match="Wrong command or module name: foo.bar.baz"): + dynamic_resolver.resolve('foo.bar.baz') + + with pytest.raises(ValueError, match="Module foo not found"): + dynamic_resolver.resolve('foo.bar') + + with pytest.raises(ValueError, match="Command foo not found in module core"): + dynamic_resolver.resolve('core.foo') + + # Test that policy fallback correctly + static_resolver = StaticPolicyResolver() + with_fallback_dynamic_resolver = dynamic_resolver.with_fallback(static_resolver) + + assert with_fallback_dynamic_resolver.resolve('ft.aggregate').request_policy == RequestPolicy.DEFAULT_KEYLESS + assert with_fallback_dynamic_resolver.resolve('ft.aggregate').response_policy == ResponsePolicy.DEFAULT_KEYLESS + + # Extended chain with one more resolver + mock_command_parser = Mock(spec=CommandsParser) + foo_bar_policy = CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS) + + mock_command_parser.get_command_policies.return_value = { + 'foo': { + 'bar': foo_bar_policy, + } + } + another_dynamic_resolver = DynamicPolicyResolver(mock_command_parser) + with_fallback_static_resolver = static_resolver.with_fallback(another_dynamic_resolver) + with_double_fallback_dynamic_resolver = dynamic_resolver.with_fallback(with_fallback_static_resolver) + + assert with_double_fallback_dynamic_resolver.resolve('foo.bar') == foo_bar_policy \ No newline at end of file