
Commit b996af0

feat: add page-aligned tensor creator for host KV cache.
1 parent 7489584 commit b996af0

6 files changed: +130 −73 lines

xllm/core/distributed_runtime/comm_channel.cpp

Lines changed: 1 addition & 5 deletions
@@ -351,11 +351,7 @@ class ClientStreamReceiver : public brpc::StreamInputHandler {

   ~ClientStreamReceiver() {
     if (!promise_set_.exchange(true)) {
-      try {
-        close_promise_.set_value();
-      } catch (const std::exception& e) {
-        LOG(WARNING) << "Exception in destructor: " << e.what();
-      }
+      close_promise_.set_value();
     }
   }

xllm/core/distributed_runtime/worker_service.cpp

Lines changed: 1 addition & 5 deletions
@@ -442,11 +442,7 @@ class ServerStreamHandler : public brpc::StreamInputHandler {
  public:
   ~ServerStreamHandler() {
     if (!promise_set_.exchange(true)) {
-      try {
-        close_promise_.set_value();
-      } catch (const std::exception& e) {
-        LOG(WARNING) << "Exception in destructor: " << e.what();
-      }
+      close_promise_.set_value();
     }
   }

xllm/core/framework/kv_cache/kv_cache_store.cpp

Lines changed: 10 additions & 22 deletions
@@ -55,30 +55,18 @@ bool KVCacheStore::init(const StoreConfig& config,
   LOG(INFO) << "v_cache_size_per_block: " << v_cache_size_per_block_;

   if (config_.protocol == "rdma") {
-    for (int block = 0; block < host_kv_caches_->size(); block++) {
-      void* key_cache = static_cast<char*>(
-          host_kv_caches_->at(block).get_k_cache().data_ptr());
-
-      auto register_k_result = client_ptr_->RegisterLocalMemory(
-          key_cache, k_cache_size_per_block_, "cpu:0", false, false);
-
-      if (!register_k_result.has_value()) {
-        LOG(ERROR) << "Failed to register local memory for key cache: "
-                   << toString(register_k_result.error());
-        return false;
-      }
-
-      void* value_cache = static_cast<char*>(
-          host_kv_caches_->at(block).get_v_cache().data_ptr());
-
-      auto register_v_result = client_ptr_->RegisterLocalMemory(
-          value_cache, v_cache_size_per_block_, "cpu:0", false, false);
-
-      if (!register_v_result.has_value()) {
-        LOG(ERROR) << "Failed to register local memory for value cache: "
-                   << toString(register_v_result.error());
+    if (config_.total_size > 0 && config_.tensor_data != nullptr) {
+      auto result = client_ptr_->RegisterLocalMemory(
+          config_.tensor_data, config_.total_size, "cpu:0", false, false);
+      if (!result.has_value()) {
+        LOG(ERROR) << "Failed to register local memory: "
+                   << toString(result.error());
         return false;
       }
+    } else {
+      LOG(FATAL) << "rdma must RegisterLocalMemory, but got register size: "
+                 << config_.total_size
+                 << ", and data ptr: " << uint64_t(config_.tensor_data);
     }
   }
   is_initialized_ = true;
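The rdma branch now registers the whole host pool with a single call instead of registering every block's K and V tensor separately. That only works because each per-block tensor is carved out of the one region described by config_.tensor_data and config_.total_size. A minimal sanity-check sketch of that precondition, written against the plain torch::Tensor API (not part of the commit):

#include <torch/torch.h>
#include <utility>
#include <vector>

// Returns true if every K/V tensor lies fully inside [base, base + total_size).
bool region_covers(
    const std::vector<std::pair<torch::Tensor, torch::Tensor>>& blocks,
    const void* base, size_t total_size) {
  const char* lo = static_cast<const char*>(base);
  const char* hi = lo + total_size;
  for (const auto& [k, v] : blocks) {
    const char* kp = static_cast<const char*>(k.data_ptr());
    const char* vp = static_cast<const char*>(v.data_ptr());
    if (kp < lo || kp + k.nbytes() > hi) return false;
    if (vp < lo || vp + v.nbytes() > hi) return false;
  }
  return true;
}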

xllm/core/framework/kv_cache/kv_cache_store.h

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@ struct StoreConfig {
   std::string master_server_address = "";
   int replica_num = 1;
   uint32_t tp_rank = 0;
+  size_t total_size = 0;
+  void* tensor_data = nullptr;
 };

 class KVCacheStore {

xllm/core/runtime/worker_impl.cpp

Lines changed: 88 additions & 36 deletions
@@ -75,21 +75,13 @@ WorkerImpl::WorkerImpl(const ParallelArgs& parallel_args,
   threadpool_.schedule([this]() mutable { device_.set_device(); });
   for (int i = 0; i < h2d_threadpool_.size(); i++) {
     h2d_threadpool_.schedule_with_tid(
-        [this]() mutable {
-          device_.set_device();
-          h2d_stream_[std::this_thread::get_id()] =
-              device_.get_stream_from_pool(TIMEOUT_MS);
-        },
-        i);
+        [this]() mutable { device_.set_device(); }, i);
+    copy_stream_.enqueue(device_.get_stream_from_pool(TIMEOUT_MS));
   }
   for (int i = 0; i < d2h_threadpool_.size(); i++) {
     d2h_threadpool_.schedule_with_tid(
-        [this]() mutable {
-          device_.set_device();
-          d2h_stream_[std::this_thread::get_id()] =
-              device_.get_stream_from_pool(TIMEOUT_MS);
-        },
-        i);
+        [this]() mutable { device_.set_device(); }, i);
+    copy_stream_.enqueue(device_.get_stream_from_pool(TIMEOUT_MS));
   }

   prepare_stream_ = device_.get_stream_from_pool();
@@ -152,18 +144,9 @@ bool WorkerImpl::allocate_host_kv_cache(
   host_kv_cache_shape[1][0] = num_layers;

   // create a KVCache shape: block_size * [layers, token, head, dim]
-  host_kv_caches_.reserve(host_bolck_size);
+  aligned_tensor_creater_ = std::make_unique<AlignedTensorCreater>(
+      host_kv_cache_shape, dtype_, host_bolck_size, &host_kv_caches_);

-  for (int64_t i = 0; i < host_bolck_size; ++i) {
-    torch::Tensor key_cache, value_cache;
-    key_cache = torch::empty(host_kv_cache_shape[0],
-                             torch::dtype(dtype_).device(torch::kCPU))
-                    .pin_memory();
-    value_cache = torch::empty(host_kv_cache_shape[1],
-                               torch::dtype(dtype_).device(torch::kCPU))
-                      .pin_memory();
-    host_kv_caches_.emplace_back(key_cache, value_cache);
-  }
   LOG(INFO) << "Initializing host kv block size: " << host_bolck_size;

   int32_t device_id = device_.index();
@@ -188,6 +171,8 @@ bool WorkerImpl::allocate_host_kv_cache(
   config.tp_rank = options_.dp_size() > 1
                        ? options_.node_rank() % options_.dp_size()
                        : options_.node_rank();
+  config.total_size = aligned_tensor_creater_->get_total_size();
+  config.tensor_data = aligned_tensor_creater_->get_base_ptr();

   if (!KVCacheStore::get_instance().init(config, &host_kv_caches_)) {
     LOG(ERROR) << "Init KVCacheStore fail!";
@@ -805,9 +790,6 @@ uint32_t WorkerImpl::offload_kv_blocks(

 bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
 #if defined(USE_NPU)
-  CHECK(d2h_stream_.count(std::this_thread::get_id()) != 0)
-      << "WorkerImpl::d2h_batch_copy can only be called in d2h_threadpool_.";
-
   const int64_t num_layers = context_.get_model_args().n_layers();
   uint32_t num_batches = block_transfer_info.size() * num_layers * 2;
   void** srcs = new void*[num_batches];
@@ -840,8 +822,9 @@ bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
     }
   }

-  c10::StreamGuard streamGuard =
-      d2h_stream_[std::this_thread::get_id()]->set_stream_guard();
+  std::unique_ptr<Stream> stream;
+  copy_stream_.wait_dequeue(stream);
+  c10::StreamGuard streamGuard = stream->set_stream_guard();

   // TODO(kangmeng): change to async API
   aclError ret = aclrtMemcpyBatch(dsts,
@@ -856,14 +839,18 @@ bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   if (ret != 0 || fail_index != SIZE_MAX) {
     LOG(ERROR) << "aclrtMemcpyBatch error: " << ret
                << ", fail_index:" << fail_index;
+    copy_stream_.enqueue(std::move(stream));
     return false;
   }

-  if (d2h_stream_[std::this_thread::get_id()]->synchronize() != 0) {
+  if (stream->synchronize() != 0) {
     LOG(ERROR) << "d2h_batch_copy timeout!";
+    copy_stream_.enqueue(std::move(stream));
     return false;
   }

+  copy_stream_.enqueue(std::move(stream));
+
   delete[] dsts;
   delete[] srcs;
   delete[] copy_size;
@@ -875,8 +862,6 @@ bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
 bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
                                 Slice<BlockTransferInfo>& block_transfer_info) {
 #if defined(USE_NPU)
-  CHECK(h2d_stream_.count(std::this_thread::get_id()) != 0)
-      << "WorkerImpl::h2d_batch_copy can only be called in h2d_threadpool_.";
   CHECK(block_transfer_info.size() < BATCH_COPY_MAX_SIZE / 2)
       << "h2d_batch_copy support copy blocks less than "
       << BATCH_COPY_MAX_SIZE / 2 << ", but got " << block_transfer_info.size();
@@ -903,9 +888,10 @@ bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
   aclrtMemcpyBatchAttr attrs[1] = {h2d_attrs_};
   size_t attrs_indexes[1] = {0};

-  c10::StreamGuard streamGuard =
-      h2d_stream_[std::this_thread::get_id()]->set_stream_guard();
-  auto stream = h2d_stream_[std::this_thread::get_id()]->get_stream()->stream();
+  std::unique_ptr<Stream> stream;
+  copy_stream_.wait_dequeue(stream);
+  c10::StreamGuard streamGuard = stream->set_stream_guard();
+
   aclError ret = 0;

   for (int layer_id = 0; layer_id < num_layers; layer_id++) {
@@ -946,7 +932,7 @@ bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
       LOG(ERROR) << "aclrtMemcpyBatch error: " << ret
                  << ", fail_index:" << fail_index;
     } else {
-      ret = aclrtRecordEvent(*event, stream);
+      ret = aclrtRecordEvent(*event, stream->get_stream()->stream());
      if (ret != 0) {
        LOG(ERROR) << "aclrtRecordEvent error: " << ret;
      }
@@ -955,10 +941,12 @@ bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
     if (ret != 0) break;
   }

-  if (h2d_stream_[std::this_thread::get_id()]->synchronize() != 0) {
+  if (stream->synchronize() != 0) {
     LOG(ERROR) << "h2d_batch_copy timeout!";
+    copy_stream_.enqueue(std::move(stream));
     return false;
   }
+  copy_stream_.enqueue(std::move(stream));

   delete[] dsts;
   delete[] srcs;
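With the per-thread h2d_stream_/d2h_stream_ maps gone, both copy paths check a stream out of a shared moodycamel::BlockingConcurrentQueue and must hand it back on every exit path, including the early error returns above. A stripped-down sketch of that checkout/return discipline, using a stand-in Stream type rather than xllm's real wrapper:

#include <blockingconcurrentqueue.h>  // moodycamel::BlockingConcurrentQueue
#include <memory>

struct Stream {                 // stand-in for xllm's Stream wrapper
  int synchronize() { return 0; }
};

moodycamel::BlockingConcurrentQueue<std::unique_ptr<Stream>> copy_stream_;

bool do_copy() {
  std::unique_ptr<Stream> stream;
  copy_stream_.wait_dequeue(stream);        // blocks until a pooled stream is free
  bool ok = (stream->synchronize() == 0);   // do the work on this stream
  copy_stream_.enqueue(std::move(stream));  // always return the stream to the pool
  return ok;
}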
@@ -1026,4 +1014,68 @@ uint32_t WorkerImpl::prefetch_from_storage(
       .get();
 }

+AlignedTensorCreater::AlignedTensorCreater(
+    const std::vector<std::vector<int64_t>>& tensor_shapes,
+    const torch::ScalarType dtype,
+    const uint32_t num_tensors,
+    std::vector<xllm::KVCache>* tensors) {
+  CHECK(tensor_shapes.size() == 2)
+      << "tensor_shapes.size() must equal to 2, but got "
+      << tensor_shapes.size();
+
+  int64_t elements_per_k_tensor = 1;
+  int64_t elements_per_v_tensor = 1;
+
+  for (auto dim : tensor_shapes[0]) {
+    elements_per_k_tensor *= dim;
+  }
+  for (auto dim : tensor_shapes[1]) {
+    elements_per_v_tensor *= dim;
+  }
+
+  size_t element_size = torch::elementSize(dtype);
+  size_t bytes_per_k_tensor = elements_per_k_tensor * element_size;
+  size_t bytes_per_v_tensor = elements_per_v_tensor * element_size;
+  size_t page_size = sysconf(_SC_PAGESIZE);
+  total_size_ = num_tensors * (bytes_per_k_tensor + bytes_per_v_tensor);
+  total_size_ = ((total_size_ + page_size - 1) / page_size) * page_size;
+
+  base_ptr_ = mmap(nullptr,
+                   total_size_,
+                   PROT_READ | PROT_WRITE,
+                   MAP_PRIVATE | MAP_ANONYMOUS,
+                   -1,
+                   0);
+
+  if (base_ptr_ == MAP_FAILED) {
+    LOG(FATAL) << "Failed to allocate aligned memory pool!";
+  }
+
+  if (mlock(base_ptr_, total_size_) != 0) {
+    munmap(base_ptr_, total_size_);
+    LOG(FATAL) << "Failed to lock memory pool!";
+  }
+
+  size_t current_offset = 0;
+  auto options = torch::TensorOptions().dtype(dtype).device(torch::kCPU);
+  tensors->reserve(num_tensors);
+
+  for (size_t i = 0; i < num_tensors; ++i) {
+    void* k_tensor_ptr = static_cast<char*>(base_ptr_) + current_offset;
+    torch::Tensor k_tensor =
+        torch::from_blob(k_tensor_ptr, tensor_shapes[0], options);
+    current_offset += bytes_per_k_tensor;
+
+    void* v_tensor_ptr = static_cast<char*>(base_ptr_) + current_offset;
+    torch::Tensor v_tensor =
+        torch::from_blob(v_tensor_ptr, tensor_shapes[1], options);
+    current_offset += bytes_per_v_tensor;
+
+    tensors->emplace_back(k_tensor, v_tensor);
+  }
+
+  LOG(INFO) << "Page aligned: "
+            << ((uintptr_t)base_ptr_ % page_size == 0 ? "YES" : "NO");
+}
+
 } // namespace xllm
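The constructor above replaces the old per-block pin_memory() allocations with one page-aligned, mlock'ed mmap pool that backs every K/V tensor via torch::from_blob. A self-contained sketch of the same allocation pattern, outside xllm and with made-up sizes:

#include <sys/mman.h>
#include <unistd.h>
#include <torch/torch.h>

int main() {
  const size_t page = sysconf(_SC_PAGESIZE);
  size_t bytes = 2 * 16 * sizeof(float);        // two 16-element fp32 tensors
  bytes = ((bytes + page - 1) / page) * page;   // round up to a page boundary
  void* base = mmap(nullptr, bytes, PROT_READ | PROT_WRITE,
                    MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
  if (base == MAP_FAILED) return 1;
  if (mlock(base, bytes) != 0) {                // pin the pool for DMA/RDMA use
    munmap(base, bytes);
    return 1;
  }

  auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
  torch::Tensor k = torch::from_blob(base, {16}, opts);  // views into the pool,
  torch::Tensor v = torch::from_blob(                    // nothing is copied
      static_cast<char*>(base) + 16 * sizeof(float), {16}, opts);
  k.fill_(1.0f);
  v.fill_(2.0f);

  munlock(base, bytes);
  munmap(base, bytes);  // k and v must not be touched after this point
  return 0;
}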

xllm/core/runtime/worker_impl.h

Lines changed: 28 additions & 5 deletions
@@ -16,6 +16,7 @@ limitations under the License.
 #pragma once

 #include <folly/futures/Future.h>
+#include <sys/mman.h>
 #include <torch/torch.h>

 #include <memory>
@@ -45,6 +46,8 @@ limitations under the License.

 namespace xllm {

+class AlignedTensorCreater;
+
 class WorkerImpl {
  public:
   enum Status : int8_t {
@@ -214,11 +217,8 @@ class WorkerImpl {
   ThreadPool d2h_threadpool_{5};
   ThreadPool batchget_threadpool_{5};
   ThreadPool batchput_threadpool_{2};
-  // copy streams
-  // only can be used in h2d_threadpool_
-  std::unordered_map<std::thread::id, std::unique_ptr<Stream>> h2d_stream_;
-  // only can be used in d2h_threadpool_
-  std::unordered_map<std::thread::id, std::unique_ptr<Stream>> d2h_stream_;
+  // copy streams only can be used in h2d_threadpool_ and d2h_threadpool_
+  moodycamel::BlockingConcurrentQueue<std::unique_ptr<Stream>> copy_stream_;

   // dtype of the model
   torch::ScalarType dtype_;
@@ -237,6 +237,7 @@
   // kv caches
   std::vector<xllm::KVCache> kv_caches_;
   std::vector<xllm::KVCache> host_kv_caches_;
+  std::unique_ptr<AlignedTensorCreater> aligned_tensor_creater_;

   // causal LM model
   std::unique_ptr<CausalLM> model_;
@@ -277,4 +278,26 @@
       layer_wise_load_synchronizer_;
 };

+class AlignedTensorCreater {
+ private:
+  void* base_ptr_;
+  size_t total_size_;
+
+ public:
+  AlignedTensorCreater(const std::vector<std::vector<int64_t>>& tensor_shapes,
+                       const torch::ScalarType dtype,
+                       const uint32_t num_tensors,
+                       std::vector<xllm::KVCache>* tensors);
+
+  ~AlignedTensorCreater() {
+    if (base_ptr_ != nullptr) {
+      munlock(base_ptr_, total_size_);
+      munmap(base_ptr_, total_size_);
+    }
+  }
+
+  void* get_base_ptr() const { return base_ptr_; }
+  size_t get_total_size() const { return total_size_; }
+};
+
 } // namespace xllm
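Because the tensors handed back through the tensors argument are torch::from_blob views, they do not own their storage; the AlignedTensorCreater and its mmap'ed pool have to outlive host_kv_caches_, which is why WorkerImpl keeps the creator as a member. A hedged usage sketch with made-up shapes and counts:

std::vector<xllm::KVCache> caches;
std::vector<std::vector<int64_t>> shapes = {
    {4, 16, 8, 64},   // K block: [layers, tokens, heads, dim] (illustrative)
    {4, 16, 8, 64}};  // V block: same layout in this sketch
auto creater = std::make_unique<AlignedTensorCreater>(
    shapes, torch::kFloat16, /*num_tensors=*/128, &caches);
// ... register creater->get_base_ptr() / get_total_size() with the store,
//     run copies through `caches` ...
caches.clear();   // drop the views first
creater.reset();  // then munlock/munmap the pool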
