Skip to content
628 changes: 585 additions & 43 deletions docs/extras/integrations/vectorstores/redis.ipynb

Large diffs are not rendered by default.

2 changes: 0 additions & 2 deletions libs/langchain/langchain/memory/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Any, Dict, List

from langchain.schema.messages import get_buffer_string # noqa: 401


def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
"""
Expand Down
58 changes: 53 additions & 5 deletions libs/langchain/langchain/utilities/redis.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,64 @@
from __future__ import annotations

import logging
from typing import (
TYPE_CHECKING,
Any,
)
import re
from typing import TYPE_CHECKING, Any, List, Optional, Pattern
from urllib.parse import urlparse

import numpy as np

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from redis.client import Redis as RedisType

logger = logging.getLogger(__name__)

def _array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes:
return np.array(array).astype(dtype).tobytes()


class TokenEscaper:
"""
Escape punctuation within an input string.
"""

# Characters that RediSearch requires us to escape during queries.
# Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/]"

def __init__(self, escape_chars_re: Optional[Pattern] = None):
if escape_chars_re:
self.escaped_chars_re = escape_chars_re
else:
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)

def escape(self, value: str) -> str:
def escape_symbol(match: re.Match) -> str:
value = match.group(0)
return f"\\{value}"

return self.escaped_chars_re.sub(escape_symbol, value)


def check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
"""Check if the correct Redis modules are installed."""
installed_modules = client.module_list()
installed_modules = {
module[b"name"].decode("utf-8"): module for module in installed_modules
}
for module in required_modules:
if module["name"] in installed_modules and int(
installed_modules[module["name"]][b"ver"]
) >= int(module["ver"]):
return
# otherwise raise error
error_message = (
"Redis cannot be used as a vector database without RediSearch >=2.4"
"Please head to https://redis.io/docs/stack/search/quick_start/"
"to know more about installing the RediSearch module within Redis Stack."
)
logger.error(error_message)
raise ValueError(error_message)


def get_client(redis_url: str, **kwargs: Any) -> RedisType:
Expand Down
Loading