Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ max-line-length = 90
select = A,B,C,D,E,F,G,H,I,J,K,L,M,N,O,P,Q,R,S,T,U,V,W,X,Y,Z,B901,B902,B903,B950
# E226: Missing whitespace around arithmetic operators can help group things together.
# E501,W505: Superseeded by B950 (from Bugbear)
# E704: Allow overloaded function to use Ellipsis as body
# E722: Superseeded by B001 (from Bugbear)
# W503: Mutually exclusive with W504.
ignore = E226,E501,E722,W503,W505
ignore = E226,E501,E704,E722,W503,W505
per-file-ignores =
# S*: Bandit security checks not useful in tests.
tests/*:S
Expand Down
225 changes: 162 additions & 63 deletions aiocache/backends/memcached.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,129 @@
import asyncio
from typing import Optional
import sys
from typing import Any, Iterable, Literal, Union, overload

import aiomcache

from aiocache.base import BaseCache
from aiocache.base import BaseCache, BaseCacheArgs, _Conn
from aiocache.serializers import JsonSerializer

if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack # noqa: I900

class MemcachedBackend(BaseCache[bytes]):
def __init__(self, host="127.0.0.1", port=11211, pool_size=2, **kwargs):

class MemcachedCache(BaseCache[bytes]):
"""
Memcached cache implementation with the following components as defaults:
- serializer: :class:`aiocache.serializers.JsonSerializer`
- plugins: []
Config options are:
:param serializer: obj derived from :class:`aiocache.serializers.BaseSerializer`.
:param plugins: list of :class:`aiocache.plugins.BasePlugin` derived classes.
:param namespace: string to use as default prefix for the key used in all operations of
the backend. Default is an empty string, "".
:param timeout: int or float in seconds specifying maximum timeout for the operations to last.
By default its 5.
:param endpoint: str with the endpoint to connect to. Default is 127.0.0.1.
:param port: int with the port to connect to. Default is 11211.
:param pool_size: int size for memcached connections pool. Default is 2.
"""

NAME = "memcached"

def __init__(
self,
host: str = "127.0.0.1",
port: int = 11211,
pool_size: int = 2,
**kwargs: Unpack[BaseCacheArgs],
) -> None:
if "serializer" not in kwargs:
kwargs["serializer"] = JsonSerializer()
super().__init__(**kwargs)
self.host = host
self.port = port
self.pool_size = int(pool_size)
self.client = aiomcache.Client(
self.host, self.port, pool_size=self.pool_size
)

async def _get(self, key, encoding="utf-8", _conn=None):
self.client = aiomcache.Client(self.host, self.port, pool_size=self.pool_size)

@overload
async def _get(
self,
key: bytes,
encoding: str = "utf-8",
_conn: Union[_Conn, None] = None,
) -> Union[str, None]: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

@overload
async def _get(
self, key: bytes, encoding: None, _conn: Union[_Conn, None] = None
) -> Union[bytes, None]: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

async def _get(
self,
key: bytes,
encoding: Union[str, None] = "utf-8",
_conn: Union[_Conn, None] = None,
) -> Union[bytes, str, None]:
value = await self.client.get(key)
if encoding is None or value is None:
return value
return value.decode(encoding)

async def _gets(self, key, encoding="utf-8", _conn=None):
async def _gets(
self,
key: Union[bytes, str],
encoding: str = "utf-8",
_conn: Union[_Conn, None] = None,
) -> Union[int, None]:
key = key.encode() if isinstance(key, str) else key
_, token = await self.client.gets(key)
return token

async def _multi_get(self, keys, encoding="utf-8", _conn=None):
values = []
for value in await self.client.multi_get(*keys):
if encoding is None or value is None:
values.append(value)
else:
values.append(value.decode(encoding))
return values

async def _set(self, key, value, ttl=0, _cas_token=None, _conn=None):
@overload
async def _multi_get(
self,
keys: Iterable[bytes],
encoding: str = "utf-8",
_conn: Union[_Conn, None] = None,
) -> list[Union[str, None]]: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

@overload
async def _multi_get(
self,
keys: Iterable[bytes],
encoding: None,
_conn: Union[_Conn, None] = None,
) -> list[Union[bytes, None]]: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

async def _multi_get(
self,
keys: Iterable[bytes],
encoding: Union[str, None] = "utf-8",
_conn: Union[_Conn, None] = None,
) -> Union[
list[Union[str, None]],
list[Union[bytes, None]],
]:
raw_values = await self.client.multi_get(*keys)
if encoding is None:
return list(raw_values)

return [
None if value is None else value.decode(encoding) for value in raw_values
]

async def _set(
self,
key: bytes,
value: Union[str, bytes],
ttl: int = 0,
_cas_token: Union[int, None] = None,
_conn: Union[_Conn, None] = None,
) -> bool:
value = value.encode() if isinstance(value, str) else value
if _cas_token is not None:
return await self._cas(key, value, _cas_token, ttl=ttl, _conn=_conn)
Expand All @@ -46,10 +132,23 @@
except aiomcache.exceptions.ValidationException as e:
raise TypeError("aiomcache error: {}".format(str(e)))

async def _cas(self, key, value, token, ttl=None, _conn=None):
async def _cas(
self,
key: bytes,
value: Union[str, bytes],
token: int,
ttl: Union[int, None] = None,
_conn: Union[_Conn, None] = None,
) -> bool:
value = str.encode(value) if isinstance(value, str) else value
return await self.client.cas(key, value, token, exptime=ttl or 0)

async def _multi_set(self, pairs, ttl=0, _conn=None):
async def _multi_set(
self,
pairs: Iterable[tuple[bytes, Union[str, bytes]]],
ttl: int = 0,
_conn: Union[_Conn, None] = None,
) -> bool:
tasks = []
for key, value in pairs:
value = str.encode(value) if isinstance(value, str) else value
Expand All @@ -62,21 +161,31 @@

return True

async def _add(self, key, value, ttl=0, _conn=None):
value = str.encode(value) if isinstance(value, str) else value
async def _add(
self,
key: bytes,
value: Union[str, bytes],
ttl: int = 0,
_conn: Union[_Conn, None] = None,
) -> bool:
value_bytes = str.encode(value) if isinstance(value, str) else value
try:
ret = await self.client.add(key, value, exptime=ttl or 0)
ret = await self.client.add(key, value_bytes, exptime=ttl or 0)
except aiomcache.exceptions.ValidationException as e:
raise TypeError("aiomcache error: {}".format(str(e)))
if not ret:
raise ValueError("Key {} already exists, use .set to update the value".format(key))
raise ValueError(
"Key {!r} already exists, use .set to update the value".format(key)
)

return True

async def _exists(self, key, _conn=None):
async def _exists(self, key: bytes, _conn: Union[_Conn, None] = None) -> bool:
return await self.client.append(key, b"")

async def _increment(self, key, delta, _conn=None):
async def _increment(
self, key: bytes, delta: int, _conn: Union[_Conn, None] = None
) -> int:
incremented = None
try:
if delta > 0:
Expand All @@ -91,66 +200,56 @@

return incremented or delta

async def _expire(self, key, ttl, _conn=None):
async def _expire(
self, key: bytes, ttl: int, _conn: Union[_Conn, None] = None
) -> bool:
return await self.client.touch(key, ttl)

async def _delete(self, key, _conn=None):
async def _delete(
self, key: bytes, _conn: Union[str, None] = None
) -> Literal[1, 0]:
return 1 if await self.client.delete(key) else 0

async def _clear(self, namespace=None, _conn=None):
async def _clear(
self, namespace: Union[str, None] = None, _conn: Union[_Conn, None] = None
) -> bool:
if namespace:
raise ValueError("MemcachedBackend doesnt support flushing by namespace")
raise ValueError("MemcachedCache doesnt support flushing by namespace")
else:
await self.client.flush_all()
return True

async def _raw(self, command, *args, encoding="utf-8", _conn=None, **kwargs):
async def _raw(
self,
command: str,
*args: Any,
encoding: str = "utf-8",
_conn: Union[_Conn, None] = None,
**kwargs: Any,
) -> Any:
value = await getattr(self.client, command)(*args, **kwargs)
if command in {"get", "multi_get"}:
if encoding is not None and value is not None:
return value.decode(encoding)
return value

async def _redlock_release(self, key, _):
async def _redlock_release(self, key: bytes, _: Any) -> Literal[1, 0]:
# Not ideal, should check the value coincides first but this would introduce
# race conditions
return await self._delete(key)

async def _close(self, *args, _conn=None, **kwargs):
async def _close(
self, *args: Any, _conn: Union[_Conn, None] = None, **kwargs: Any
) -> None:
await self.client.close()

def build_key(self, key: str, namespace: Optional[str] = None) -> bytes:
def build_key(self, key: str, namespace: Union[str, None] = None) -> bytes:
ns_key = self._str_build_key(key, namespace).replace(" ", "_")
return str.encode(ns_key)


class MemcachedCache(MemcachedBackend):
"""
Memcached cache implementation with the following components as defaults:
- serializer: :class:`aiocache.serializers.JsonSerializer`
- plugins: []
Config options are:
:param serializer: obj derived from :class:`aiocache.serializers.BaseSerializer`.
:param plugins: list of :class:`aiocache.plugins.BasePlugin` derived classes.
:param namespace: string to use as default prefix for the key used in all operations of
the backend. Default is an empty string, "".
:param timeout: int or float in seconds specifying maximum timeout for the operations to last.
By default its 5.
:param endpoint: str with the endpoint to connect to. Default is 127.0.0.1.
:param port: int with the port to connect to. Default is 11211.
:param pool_size: int size for memcached connections pool. Default is 2.
"""

NAME = "memcached"

def __init__(self, serializer=None, **kwargs):
super().__init__(serializer=serializer or JsonSerializer(), **kwargs)

@classmethod
def parse_uri_path(cls, path):
def parse_uri_path(cls, path: str) -> dict[Any, Any]:
return {}

def __repr__(self): # pragma: no cover
def __repr__(self) -> str: # pragma: no cover
return "MemcachedCache ({}:{})".format(self.host, self.port)
Loading
Loading