Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 202 additions & 2 deletions redis/_parsers/commands.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,43 @@
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

from redis.exceptions import RedisError, ResponseError
from redis.exceptions import RedisError, ResponseError, IncorrectPolicyType
from redis.utils import str_if_bytes

if TYPE_CHECKING:
from redis.asyncio.cluster import ClusterNode

class RequestPolicy(Enum):
ALL_NODES = 'all_nodes'
ALL_SHARDS = 'all_shards'
MULTI_SHARD = 'multi_shard'
SPECIAL = 'special'
DEFAULT_KEYLESS = 'default_keyless'
DEFAULT_KEYED = 'default_keyed'

class ResponsePolicy(Enum):
ONE_SUCCEEDED = 'one_succeeded'
ALL_SUCCEEDED = 'all_succeeded'
AGG_LOGICAL_AND = 'agg_logical_and'
AGG_LOGICAL_OR = 'agg_logical_or'
AGG_MIN = 'agg_min'
AGG_MAX = 'agg_max'
AGG_SUM = 'agg_sum'
SPECIAL = 'special'
DEFAULT_KEYLESS = 'default_keyless'
DEFAULT_KEYED = 'default_keyed'

class CommandPolicies:
def __init__(
self,
request_policy: RequestPolicy = RequestPolicy.DEFAULT_KEYLESS,
response_policy: ResponsePolicy = ResponsePolicy.DEFAULT_KEYLESS
):
self.request_policy = request_policy
self.response_policy = response_policy

PolicyRecords = dict[str, dict[str, CommandPolicies]]

class AbstractCommandsParser:
def _get_pubsub_keys(self, *args):
Expand Down Expand Up @@ -64,7 +96,8 @@ class CommandsParser(AbstractCommandsParser):

def __init__(self, redis_connection):
self.commands = {}
self.initialize(redis_connection)
self.redis_connection = redis_connection
self.initialize(self.redis_connection)

def initialize(self, r):
commands = r.command()
Expand Down Expand Up @@ -169,6 +202,173 @@ def _get_moveable_keys(self, redis_conn, *args):
raise e
return keys

def _is_keyless_command(self, command_name: str, subcommand_name: Optional[str]=None) -> bool:
"""
Determines whether a given command or subcommand is considered "keyless".

A keyless command does not operate on specific keys, which is determined based
on the first key position in the command or subcommand details. If the command
or subcommand's first key position is zero or negative, it is treated as keyless.

Parameters:
command_name: str
The name of the command to check.
subcommand_name: Optional[str], default=None
The name of the subcommand to check, if applicable. If not provided,
the check is performed only on the command.

Returns:
bool
True if the specified command or subcommand is considered keyless,
False otherwise.

Raises:
ValueError
If the specified subcommand is not found within the command or the
specified command does not exist in the available commands.
"""
if subcommand_name:
for subcommand in self.commands.get(command_name)['subcommands']:
if str_if_bytes(subcommand[0]) == subcommand_name:
parsed_subcmd = self.parse_subcommand(subcommand)
return parsed_subcmd['first_key_pos'] <= 0
raise ValueError(f"Subcommand {subcommand_name} not found in command {command_name}")
else:
command_details = self.commands.get(command_name, None)
if command_details is not None:
return command_details['first_key_pos'] <= 0

raise ValueError(f"Command {command_name} not found in commands")

def get_command_policies(self) -> PolicyRecords:
"""
Retrieve and process the command policies for all commands and subcommands.

This method traverses through commands and subcommands, extracting policy details
from associated data structures and constructing a dictionary of commands with their
associated policies. It supports nested data structures and handles both main commands
and their subcommands.

Returns:
PolicyRecords: A collection of commands and subcommands associated with their
respective policies.

Raises:
IncorrectPolicyType: If an invalid policy type is encountered during policy extraction.
"""
command_with_policies = {}

def extract_policies(data, command_name, module_name):
"""
Recursively extract policies from nested data structures.

Args:
data: The data structure to search (can be list, dict, str, bytes, etc.)
command_name: The command name to associate with found policies
"""
if isinstance(data, (str, bytes)):
# Decode bytes to string if needed
policy = data.decode() if isinstance(data, bytes) else data

# Check if this is a policy string
if policy.startswith('request_policy') or policy.startswith('response_policy'):
if policy.startswith('request_policy'):
policy_type = policy.split(':')[1]

try:
command_with_policies[module_name][command_name].request_policy = RequestPolicy(policy_type)
except ValueError:
raise IncorrectPolicyType(f"Incorrect request policy type: {policy_type}")

if policy.startswith('response_policy'):
policy_type = policy.split(':')[1]

try:
command_with_policies[module_name][command_name].response_policy = ResponsePolicy(policy_type)
except ValueError:
raise IncorrectPolicyType(f"Incorrect response policy type: {policy_type}")

elif isinstance(data, list):
# For lists, recursively process each element
for item in data:
extract_policies(item, command_name, module_name)

elif isinstance(data, dict):
# For dictionaries, recursively process each value
for value in data.values():
extract_policies(value, command_name, module_name)

for command, details in self.commands.items():
# Check whether the command has keys
is_keyless = self._is_keyless_command(command)

if is_keyless:
default_request_policy = RequestPolicy.DEFAULT_KEYLESS
default_response_policy = ResponsePolicy.DEFAULT_KEYLESS
else:
default_request_policy = RequestPolicy.DEFAULT_KEYED
default_response_policy = ResponsePolicy.DEFAULT_KEYED

# Check if it's a core or module command
split_name = command.split('.')

if len(split_name) > 1:
module_name = split_name[0]
command_name = split_name[1]
else:
module_name = 'core'
command_name = split_name[0]

# Create a CommandPolicies object with default policies on the new command.
if command_with_policies.get(module_name, None) is None:
command_with_policies[module_name] = {command_name: CommandPolicies(
request_policy=default_request_policy,
response_policy=default_response_policy
)}
else:
command_with_policies[module_name][command_name] = CommandPolicies(
request_policy=default_request_policy,
response_policy=default_response_policy
)

tips = details.get('tips')
subcommands = details.get('subcommands')

# Process tips for the main command
if tips:
extract_policies(tips, command_name, module_name)

# Process subcommands
if subcommands:
for subcommand_details in subcommands:
# Get the subcommand name (first element)
subcmd_name = subcommand_details[0]
if isinstance(subcmd_name, bytes):
subcmd_name = subcmd_name.decode()

# Check whether the subcommand has keys
is_keyless = self._is_keyless_command(command, subcmd_name)

if is_keyless:
default_request_policy = RequestPolicy.DEFAULT_KEYLESS
default_response_policy = ResponsePolicy.DEFAULT_KEYLESS
else:
default_request_policy = RequestPolicy.DEFAULT_KEYED
default_response_policy = ResponsePolicy.DEFAULT_KEYED

subcmd_name = subcmd_name.replace('|', ' ')

# Create a CommandPolicies object with default policies on the new command.
command_with_policies[module_name][subcmd_name] = CommandPolicies(
request_policy=default_request_policy,
response_policy=default_response_policy
)

# Recursively extract policies from the rest of the subcommand details
for subcommand_detail in subcommand_details[1:]:
extract_policies(subcommand_detail, subcmd_name, module_name)

return command_with_policies

class AsyncCommandsParser(AbstractCommandsParser):
"""
Expand Down
Loading
Loading