diff --git a/python/aibrix_kvcache/.pre-commit-config.yaml b/python/aibrix_kvcache/.pre-commit-config.yaml index 82e4a6caf..5f1a5f2be 100644 --- a/python/aibrix_kvcache/.pre-commit-config.yaml +++ b/python/aibrix_kvcache/.pre-commit-config.yaml @@ -22,7 +22,7 @@ repos: hooks: - id: codespell additional_dependencies: ['tomli'] - args: ['--toml', *pyproject_toml] + args: ['--toml', *pyproject_toml, '-L', "pris"] files: *aibrix_kvcache_files - repo: https://github.com/pre-commit/mirrors-clang-format rev: v20.1.3 diff --git a/python/aibrix_kvcache/aibrix_kvcache/envs.py b/python/aibrix_kvcache/aibrix_kvcache/envs.py index 2f6148c19..dc6404642 100644 --- a/python/aibrix_kvcache/aibrix_kvcache/envs.py +++ b/python/aibrix_kvcache/aibrix_kvcache/envs.py @@ -119,6 +119,12 @@ AIBRIX_KV_CACHE_OL_HPKV_LOCAL_PORT: int = 12345 AIBRIX_KV_CACHE_OL_HPKV_USE_GDR: bool = True + # Pris Env Vars + AIBRIX_KV_CACHE_OL_PRIS_REMOTE_ADDR: str = "127.0.0.1" + AIBRIX_KV_CACHE_OL_PRIS_REMOTE_PORT: int = 6379 + AIBRIX_KV_CACHE_OL_PRIS_USE_MPUT_MGET: bool = False + AIBRIX_KV_CACHE_OL_PRIS_PASSWORD: str = "" + # RDMA Auto-Detection Env Vars # Defines the range of valid GIDs. Similar to NVSHMEM_IB_ADDR_RANGE # for NVSHMEM. It must be a valid CIDR. @@ -343,6 +349,20 @@ os.getenv("AIBRIX_KV_CACHE_OL_HPKV_USE_GDR", "1").strip().lower() in ("1", "true") ), + # ================== PRIS Env Vars ================== + "AIBRIX_KV_CACHE_OL_PRIS_REMOTE_ADDR": lambda: ( + os.getenv("AIBRIX_KV_CACHE_OL_PRIS_REMOTE_ADDR", "127.0.0.1").strip() + ), + "AIBRIX_KV_CACHE_OL_PRIS_REMOTE_PORT": lambda: int( + os.getenv("AIBRIX_KV_CACHE_OL_PRIS_REMOTE_PORT", "6379") + ), + "AIBRIX_KV_CACHE_OL_PRIS_USE_MPUT_MGET": lambda: ( + os.getenv("AIBRIX_KV_CACHE_OL_PRIS_USE_MPUT_MGET", "0").strip().lower() + in ("1", "true") + ), + "AIBRIX_KV_CACHE_OL_PRIS_PASSWORD": lambda: ( + os.getenv("AIBRIX_KV_CACHE_OL_PRIS_PASSWORD", "").strip() + ), # ================== RDMA Auto-Detection Env Vars ================== "AIBRIX_KV_CACHE_OL_TRANSPORT_RDMA_ADDR_RANGE": lambda: ( os.getenv( diff --git a/python/aibrix_kvcache/aibrix_kvcache/l2/connectors/connector.py b/python/aibrix_kvcache/aibrix_kvcache/l2/connectors/connector.py index 83e129fef..88b3ff59c 100644 --- a/python/aibrix_kvcache/aibrix_kvcache/l2/connectors/connector.py +++ b/python/aibrix_kvcache/aibrix_kvcache/l2/connectors/connector.py @@ -92,6 +92,10 @@ def create( from .hpkv import HPKVConnector return HPKVConnector.from_envs(conn_id, executor, **kwargs) + elif backend_name == "PRIS": + from .pris import PrisConnector + + return PrisConnector.from_envs(conn_id, executor, **kwargs) elif backend_name == "MOCK": from .mock import MockConnector diff --git a/python/aibrix_kvcache/aibrix_kvcache/l2/connectors/pris.py b/python/aibrix_kvcache/aibrix_kvcache/l2/connectors/pris.py new file mode 100644 index 000000000..b0d977abf --- /dev/null +++ b/python/aibrix_kvcache/aibrix_kvcache/l2/connectors/pris.py @@ -0,0 +1,281 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from concurrent.futures import Executor +from dataclasses import dataclass +from typing import Any, Dict, List, Sequence, Tuple + +import pris._pris as Pris +import torch +from pris.pris_client import PrisClient + +from ... import envs +from ...common import AsyncBase +from ...memory import MemoryRegion +from ...status import Status, StatusCodes +from . import Connector, ConnectorFeature, ConnectorRegisterDescriptor + + +@dataclass +class PrisConfig: + """Pris config. + Args: + remote_addr (str): remote address + remote_port (int): remote port + password (str): password + """ + + remote_addr: str + remote_port: int + password: str + + +@dataclass +class PrisRegisterDescriptor(ConnectorRegisterDescriptor): + """Pris register descriptor.""" + + reg_buf: int + + +@AsyncBase.async_wrap( + exists="_exists", + get="_get", + put="_put", + delete="_delete", + mget="_mget", + mput="_mput", +) +class PrisConnector(Connector[bytes, torch.Tensor], AsyncBase): + """Pris connector.""" + + def __init__( + self, + config: PrisConfig, + key_suffix: str, + executor: Executor, + ): + super().__init__(executor) + self.config = config + self.key_suffix = key_suffix + self.conn: PrisClient | None = None + self._register_cache: Dict[int, PrisRegisterDescriptor] = {} + + @classmethod + def from_envs( + cls, conn_id: str, executor: Executor, **kwargs + ) -> "PrisConnector": + """Create a connector from environment variables.""" + remote_addr = kwargs.get( + "addr", envs.AIBRIX_KV_CACHE_OL_PRIS_REMOTE_ADDR + ) + remote_port = kwargs.get( + "port", envs.AIBRIX_KV_CACHE_OL_PRIS_REMOTE_PORT + ) + + config = PrisConfig( + remote_addr=remote_addr, + remote_port=remote_port, + password=envs.AIBRIX_KV_CACHE_OL_PRIS_PASSWORD, + ) + return cls(config, conn_id, executor) + + @property + def name(self) -> str: + return "PRIS" + + @property + def feature(self) -> ConnectorFeature: + feature = ConnectorFeature( + rdma=True, + mput_mget=envs.AIBRIX_KV_CACHE_OL_PRIS_USE_MPUT_MGET, + ) + return feature + + def __del__(self) -> None: + self.close() + + def _key(self, key: bytes) -> str: + return key.hex() + self.key_suffix + + @Status.capture_exception + def open(self) -> Status: + """Open a connection.""" + if self.conn is None: + self.conn = PrisClient( + raddr=self.config.remote_addr, + rport=self.config.remote_port, + password=self.config.password, + ) + return Status.ok() + + @Status.capture_exception + def close(self) -> Status: + """Close a connection.""" + if self.conn is not None: + for _, desc in self._register_cache.items(): + self._deregister_mr(desc) + self._register_cache.clear() + + self.conn.close() + self.conn = None + return Status.ok() + + @Status.capture_exception + def register_slabs(self, slabs: List[torch.Tensor]) -> Status: + assert self.conn is not None + for slab in slabs: + addr = slab.data_ptr() + length = slab.numel() + reg_buf = self.conn.reg_memory(addr, length) + if reg_buf == 0: + return Status(StatusCodes.INVALID) + desc = PrisRegisterDescriptor(reg_buf) + self._register_cache[addr] = desc + return Status.ok(desc) + + def _get_register_descriptor( + self, mr: MemoryRegion + ) -> Status[PrisRegisterDescriptor]: + slab = mr.slab + addr = slab.data_ptr() + if addr not in self._register_cache: + return Status( + StatusCodes.INVALID, f"Slab(addr={addr}) hasn't been registered" + ) + return Status.ok(self._register_cache[addr]) + + def _deregister_mr(self, desc: PrisRegisterDescriptor) -> None: + assert self.conn is not None + if desc.reg_buf != 0: + self.conn.dereg_memory(desc.reg_buf) + desc.reg_buf = 0 + + @Status.capture_exception + def _exists(self, key: bytes) -> Status: + """Check if key is in the store.""" + assert self.conn is not None + if self.conn.exists(self._key(key)) == 0: + return Status.ok() + return Status(StatusCodes.NOT_FOUND) + + @Status.capture_exception + def _get(self, key: bytes, mr: MemoryRegion) -> Status: + """Get a value.""" + assert self.conn is not None + desc_status = self._get_register_descriptor(mr) + if not desc_status.is_ok(): + return Status(desc_status) + desc = desc_status.get() + sgl = Pris.SGL(mr.data_ptr(), mr.length, desc.reg_buf) + if self.conn.get(self._key(key), sgl, mr.length) != 0: + return Status(StatusCodes.ERROR) + return Status.ok() + + @Status.capture_exception + def _put(self, key: bytes, mr: MemoryRegion) -> Status: + """Put a key value pair""" + assert self.conn is not None + desc_status = self._get_register_descriptor(mr) + if not desc_status.is_ok(): + return Status(desc_status) + desc = desc_status.get() + sgl = Pris.SGL(mr.data_ptr(), mr.length, desc.reg_buf) + if self.conn.set(self._key(key), sgl) != 0: + return Status(StatusCodes.ERROR) + return Status.ok() + + def get_batches( + self, + keys: Sequence[Any], + mrs: Sequence[MemoryRegion], + batch_size: int, + ) -> Sequence[Sequence[Tuple[bytes, MemoryRegion]]]: + lists: List[List[Tuple[bytes, MemoryRegion]]] = [] + for key, mr in zip(keys, mrs): + if len(lists) == 0 or len(lists[-1]) >= batch_size: + lists.append([(key, mr)]) + else: + lists[-1].append((key, mr)) + return lists + + @Status.capture_exception + def _mget( + self, keys: Sequence[bytes], mrs: Sequence[MemoryRegion] + ) -> Sequence[Status]: + assert self.conn is not None + sgls: List[Pris.SGL] = [] + cache_keys: List[str] = [] + value_lens: List[int] = [] + op_status: List[Status] = [Status.ok()] * len(keys) + for i, mr in enumerate(mrs): + desc_status = self._get_register_descriptor(mr) + if not desc_status.is_ok(): + for j in range(i, len(keys)): + op_status[j] = Status(desc_status) + break + desc = desc_status.get() + sgl = Pris.SGL(mr.data_ptr(), mr.length, desc.reg_buf) + sgls.append(sgl) + cache_keys.append(self._key(keys[i])) + value_lens.append(mr.length) + + if len(sgls) == 0: + return op_status + + status, details = self.conn.mget(cache_keys, sgls, value_lens) + if status == 0: + return op_status + else: + for i, s in enumerate(details): + if s != 0: + op_status[i] = Status(StatusCodes.ERROR) + return op_status + + @Status.capture_exception + def _mput( + self, keys: Sequence[bytes], mrs: Sequence[MemoryRegion] + ) -> Sequence[Status]: + assert self.conn is not None + sgls: List[Pris.SGL] = [] + cache_keys: List[str] = [] + op_status: List[Status] = [Status.ok()] * len(keys) + for i, mr in enumerate(mrs): + desc_status = self._get_register_descriptor(mr) + if not desc_status.is_ok(): + for j in range(i, len(keys)): + op_status[j] = Status(desc_status) + break + desc = desc_status.get() + sgl = Pris.SGL(mr.data_ptr(), mr.length, desc.reg_buf) + sgls.append(sgl) + cache_keys.append(self._key(keys[i])) + + if len(sgls) == 0: + return op_status + + status, details = self.conn.mset(cache_keys, sgls) + if status == 0: + return op_status + else: + for i, s in enumerate(details): + if s != 0: + op_status[i] = Status(StatusCodes.ERROR) + return op_status + + @Status.capture_exception + def _delete(self, key: bytes) -> Status: + """Delete a key.""" + assert self.conn is not None + self.conn.delete(self._key(key)) + return Status.ok() diff --git a/python/aibrix_kvcache/requirements/core.txt b/python/aibrix_kvcache/requirements/core.txt index 35c4d8a5a..b25047c67 100644 --- a/python/aibrix_kvcache/requirements/core.txt +++ b/python/aibrix_kvcache/requirements/core.txt @@ -28,4 +28,5 @@ rocksdict # infinistore >= 0.2.35 # --extra-index-url https://scqq9isgq31i0fb8nt4eg.apigateway-cn-beijing.volceapi.com/simple/ # hpkv >= 0.0.1 +# pris >= 0.0.4 # pyverbs