Skip to content

Commit 5908177

Browse files
author
Haiyang Shi
committed
[Feature] KVCache: add Pris connector
Signed-off-by: Haiyang Shi <[email protected]>
1 parent 2a40822 commit 5908177

File tree

5 files changed

+307
-1
lines changed

5 files changed

+307
-1
lines changed

python/aibrix_kvcache/.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ repos:
2222
hooks:
2323
- id: codespell
2424
additional_dependencies: ['tomli']
25-
args: ['--toml', *pyproject_toml]
25+
args: ['--toml', *pyproject_toml, '-L', "pris"]
2626
files: *aibrix_kvcache_files
2727
- repo: https://github.com/pre-commit/mirrors-clang-format
2828
rev: v20.1.3

python/aibrix_kvcache/aibrix_kvcache/envs.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,12 @@
119119
AIBRIX_KV_CACHE_OL_HPKV_LOCAL_PORT: int = 12345
120120
AIBRIX_KV_CACHE_OL_HPKV_USE_GDR: bool = True
121121

122+
# Pris Env Vars
123+
AIBRIX_KV_CACHE_OL_PRIS_REMOTE_ADDR: str = "127.0.0.1"
124+
AIBRIX_KV_CACHE_OL_PRIS_REMOTE_PORT: int = 6379
125+
AIBRIX_KV_CACHE_OL_PRIS_USE_MPUT_MGET: bool = False
126+
AIBRIX_KV_CACHE_OL_PRIS_PASSWORD: str = ""
127+
122128
# RDMA Auto-Detection Env Vars
123129
# Defines the range of valid GIDs. Similar to NVSHMEM_IB_ADDR_RANGE
124130
# for NVSHMEM. It must be a valid CIDR.
@@ -343,6 +349,20 @@
343349
os.getenv("AIBRIX_KV_CACHE_OL_HPKV_USE_GDR", "1").strip().lower()
344350
in ("1", "true")
345351
),
352+
# ================== PRIS Env Vars ==================
353+
"AIBRIX_KV_CACHE_OL_PRIS_REMOTE_ADDR": lambda: (
354+
os.getenv("AIBRIX_KV_CACHE_OL_PRIS_REMOTE_ADDR", "127.0.0.1").strip()
355+
),
356+
"AIBRIX_KV_CACHE_OL_PRIS_REMOTE_PORT": lambda: int(
357+
os.getenv("AIBRIX_KV_CACHE_OL_PRIS_REMOTE_PORT", "6379")
358+
),
359+
"AIBRIX_KV_CACHE_OL_PRIS_USE_MPUT_MGET": lambda: (
360+
os.getenv("AIBRIX_KV_CACHE_OL_PRIS_USE_MPUT_MGET", "0").strip().lower()
361+
in ("1", "true")
362+
),
363+
"AIBRIX_KV_CACHE_OL_PRIS_PASSWORD": lambda: (
364+
os.getenv("AIBRIX_KV_CACHE_OL_PRIS_PASSWORD", "").strip()
365+
),
346366
# ================== RDMA Auto-Detection Env Vars ==================
347367
"AIBRIX_KV_CACHE_OL_TRANSPORT_RDMA_ADDR_RANGE": lambda: (
348368
os.getenv(

python/aibrix_kvcache/aibrix_kvcache/l2/connectors/connector.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ def create(
9292
from .hpkv import HPKVConnector
9393

9494
return HPKVConnector.from_envs(conn_id, executor, **kwargs)
95+
elif backend_name == "PRIS":
96+
from .pris import PrisConnector
97+
98+
return PrisConnector.from_envs(conn_id, executor, **kwargs)
9599
elif backend_name == "MOCK":
96100
from .mock import MockConnector
97101

Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
# Copyright 2024 The Aibrix Team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from concurrent.futures import Executor
16+
from dataclasses import dataclass
17+
from typing import Any, Dict, List, Sequence, Tuple
18+
19+
import pris._pris as Pris
20+
import torch
21+
from pris.pris_client import PrisClient
22+
23+
from ... import envs
24+
from ...common import AsyncBase
25+
from ...memory import MemoryRegion
26+
from ...status import Status, StatusCodes
27+
from . import Connector, ConnectorFeature, ConnectorRegisterDescriptor
28+
29+
30+
@dataclass
31+
class PrisConfig:
32+
"""Pris config.
33+
Args:
34+
remote_addr (str): remote address
35+
remote_port (int): remote port
36+
password (str): password
37+
"""
38+
39+
remote_addr: str
40+
remote_port: int
41+
password: str
42+
43+
44+
@dataclass
45+
class PrisRegisterDescriptor(ConnectorRegisterDescriptor):
46+
"""Pris register descriptor."""
47+
48+
reg_buf: int
49+
50+
51+
@AsyncBase.async_wrap(
52+
exists="_exists",
53+
get="_get",
54+
put="_put",
55+
delete="_delete",
56+
mget="_mget",
57+
mput="_mput",
58+
)
59+
class PrisConnector(Connector[bytes, torch.Tensor], AsyncBase):
60+
"""Pris connector."""
61+
62+
def __init__(
63+
self,
64+
config: PrisConfig,
65+
key_suffix: str,
66+
executor: Executor,
67+
):
68+
super().__init__(executor)
69+
self.config = config
70+
self.key_suffix = key_suffix
71+
self.conn: PrisClient | None = None
72+
self._register_cache: Dict[int, PrisRegisterDescriptor] = {}
73+
74+
@classmethod
75+
def from_envs(
76+
cls, conn_id: str, executor: Executor, **kwargs
77+
) -> "PrisConnector":
78+
"""Create a connector from environment variables."""
79+
remote_addr = kwargs.get(
80+
"addr", envs.AIBRIX_KV_CACHE_OL_PRIS_REMOTE_ADDR
81+
)
82+
remote_port = kwargs.get(
83+
"port", envs.AIBRIX_KV_CACHE_OL_PRIS_REMOTE_PORT
84+
)
85+
86+
config = PrisConfig(
87+
remote_addr=remote_addr,
88+
remote_port=remote_port,
89+
password=envs.AIBRIX_KV_CACHE_OL_PRIS_PASSWORD,
90+
)
91+
return cls(config, conn_id, executor)
92+
93+
@property
94+
def name(self) -> str:
95+
return "PRIS"
96+
97+
@property
98+
def feature(self) -> ConnectorFeature:
99+
feature = ConnectorFeature(
100+
rdma=True,
101+
mput_mget=envs.AIBRIX_KV_CACHE_OL_PRIS_USE_MPUT_MGET,
102+
)
103+
return feature
104+
105+
def __del__(self) -> None:
106+
self.close()
107+
108+
def _key(self, key: bytes) -> str:
109+
return key.hex() + self.key_suffix
110+
111+
@Status.capture_exception
112+
def open(self) -> Status:
113+
"""Open a connection."""
114+
if self.conn is None:
115+
self.conn = PrisClient(
116+
raddr=self.config.remote_addr,
117+
rport=self.config.remote_port,
118+
password=self.config.password,
119+
)
120+
return Status.ok()
121+
122+
@Status.capture_exception
123+
def close(self) -> Status:
124+
"""Close a connection."""
125+
if self.conn is not None:
126+
for _, desc in self._register_cache.items():
127+
self._deregister_mr(desc)
128+
self._register_cache.clear()
129+
130+
self.conn.close()
131+
self.conn = None
132+
return Status.ok()
133+
134+
@Status.capture_exception
135+
def register_slabs(self, slabs: List[torch.Tensor]) -> Status:
136+
assert self.conn is not None
137+
for slab in slabs:
138+
addr = slab.data_ptr()
139+
length = slab.numel()
140+
reg_buf = self.conn.reg_memory(addr, length)
141+
if reg_buf == 0:
142+
return Status(StatusCodes.INVALID)
143+
desc = PrisRegisterDescriptor(reg_buf)
144+
self._register_cache[addr] = desc
145+
return Status.ok(desc)
146+
147+
def _get_register_descriptor(
148+
self, mr: MemoryRegion
149+
) -> Status[PrisRegisterDescriptor]:
150+
slab = mr.slab
151+
addr = slab.data_ptr()
152+
if addr not in self._register_cache:
153+
return Status(
154+
StatusCodes.INVALID, f"Slab(addr={addr}) hasn't been registered"
155+
)
156+
return Status.ok(self._register_cache[addr])
157+
158+
def _deregister_mr(self, desc: PrisRegisterDescriptor) -> None:
159+
assert self.conn is not None
160+
if desc.reg_buf != 0:
161+
self.conn.dereg_memory(desc.reg_buf)
162+
desc.reg_buf = 0
163+
164+
@Status.capture_exception
165+
def _exists(self, key: bytes) -> Status:
166+
"""Check if key is in the store."""
167+
assert self.conn is not None
168+
if self.conn.exists(self._key(key)) == 0:
169+
return Status.ok()
170+
return Status(StatusCodes.NOT_FOUND)
171+
172+
@Status.capture_exception
173+
def _get(self, key: bytes, mr: MemoryRegion) -> Status:
174+
"""Get a value."""
175+
assert self.conn is not None
176+
desc_status = self._get_register_descriptor(mr)
177+
if not desc_status.is_ok():
178+
return Status(desc_status)
179+
desc = desc_status.get()
180+
sgl = Pris.SGL(mr.data_ptr(), mr.length, desc.reg_buf)
181+
if self.conn.get(self._key(key), sgl, mr.length) != 0:
182+
return Status(StatusCodes.ERROR)
183+
return Status.ok()
184+
185+
@Status.capture_exception
186+
def _put(self, key: bytes, mr: MemoryRegion) -> Status:
187+
"""Put a key value pair"""
188+
assert self.conn is not None
189+
desc_status = self._get_register_descriptor(mr)
190+
if not desc_status.is_ok():
191+
return Status(desc_status)
192+
desc = desc_status.get()
193+
sgl = Pris.SGL(mr.data_ptr(), mr.length, desc.reg_buf)
194+
if self.conn.set(self._key(key), sgl) != 0:
195+
return Status(StatusCodes.ERROR)
196+
return Status.ok()
197+
198+
def get_batches(
199+
self,
200+
keys: Sequence[Any],
201+
mrs: Sequence[MemoryRegion],
202+
batch_size: int,
203+
) -> Sequence[Sequence[Tuple[bytes, MemoryRegion]]]:
204+
lists: List[List[Tuple[bytes, MemoryRegion]]] = []
205+
for key, mr in zip(keys, mrs):
206+
if len(lists) == 0 or len(lists[-1]) >= batch_size:
207+
lists.append([(key, mr)])
208+
else:
209+
lists[-1].append((key, mr))
210+
return lists
211+
212+
@Status.capture_exception
213+
def _mget(
214+
self, keys: Sequence[bytes], mrs: Sequence[MemoryRegion]
215+
) -> Sequence[Status]:
216+
assert self.conn is not None
217+
sgls: List[Pris.SGL] = []
218+
cache_keys: List[str] = []
219+
value_lens: List[int] = []
220+
op_status: List[Status] = [Status.ok()] * len(keys)
221+
for i, mr in enumerate(mrs):
222+
desc_status = self._get_register_descriptor(mr)
223+
if not desc_status.is_ok():
224+
for j in range(i, len(keys)):
225+
op_status[j] = Status(desc_status)
226+
break
227+
desc = desc_status.get()
228+
sgl = Pris.SGL(mr.data_ptr(), mr.length, desc.reg_buf)
229+
sgls.append(sgl)
230+
cache_keys.append(self._key(keys[i]))
231+
value_lens.append(mr.length)
232+
233+
if len(sgls) == 0:
234+
return op_status
235+
236+
status, details = self.conn.mget(cache_keys, sgls, value_lens)
237+
if status == 0:
238+
return op_status
239+
else:
240+
for i, s in enumerate(details):
241+
if s != 0:
242+
op_status[i] = Status(StatusCodes.ERROR)
243+
return op_status
244+
245+
@Status.capture_exception
246+
def _mput(
247+
self, keys: Sequence[bytes], mrs: Sequence[MemoryRegion]
248+
) -> Sequence[Status]:
249+
assert self.conn is not None
250+
sgls: List[Pris.SGL] = []
251+
cache_keys: List[str] = []
252+
op_status: List[Status] = [Status.ok()] * len(keys)
253+
for i, mr in enumerate(mrs):
254+
desc_status = self._get_register_descriptor(mr)
255+
if not desc_status.is_ok():
256+
for j in range(i, len(keys)):
257+
op_status[j] = Status(desc_status)
258+
break
259+
desc = desc_status.get()
260+
sgl = Pris.SGL(mr.data_ptr(), mr.length, desc.reg_buf)
261+
sgls.append(sgl)
262+
cache_keys.append(self._key(keys[i]))
263+
264+
if len(sgls) == 0:
265+
return op_status
266+
267+
status, details = self.conn.mset(cache_keys, sgls)
268+
if status == 0:
269+
return op_status
270+
else:
271+
for i, s in enumerate(details):
272+
if s != 0:
273+
op_status[i] = Status(StatusCodes.ERROR)
274+
return op_status
275+
276+
@Status.capture_exception
277+
def _delete(self, key: bytes) -> Status:
278+
"""Delete a key."""
279+
assert self.conn is not None
280+
self.conn.delete(self._key(key))
281+
return Status.ok()

python/aibrix_kvcache/requirements/core.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,5 @@ rocksdict
2828
# infinistore >= 0.2.35
2929
# --extra-index-url https://scqq9isgq31i0fb8nt4eg.apigateway-cn-beijing.volceapi.com/simple/
3030
# hpkv >= 0.0.1
31+
# pris >= 0.0.4
3132
# pyverbs

0 commit comments

Comments
 (0)