@@ -75,21 +75,13 @@ WorkerImpl::WorkerImpl(const ParallelArgs& parallel_args,
   threadpool_.schedule([this]() mutable { device_.set_device(); });
   for (int i = 0; i < h2d_threadpool_.size(); i++) {
     h2d_threadpool_.schedule_with_tid(
-        [this]() mutable {
-          device_.set_device();
-          h2d_stream_[std::this_thread::get_id()] =
-              device_.get_stream_from_pool(TIMEOUT_MS);
-        },
-        i);
+        [this]() mutable { device_.set_device(); }, i);
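+    // One copy stream per worker thread is parked in the shared copy_stream_
+    // pool; batch copies borrow a stream with wait_dequeue() and return it
+    // with enqueue() when the copy finishes.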
+    copy_stream_.enqueue(device_.get_stream_from_pool(TIMEOUT_MS));
   }
   for (int i = 0; i < d2h_threadpool_.size(); i++) {
     d2h_threadpool_.schedule_with_tid(
-        [this]() mutable {
-          device_.set_device();
-          d2h_stream_[std::this_thread::get_id()] =
-              device_.get_stream_from_pool(TIMEOUT_MS);
-        },
-        i);
+        [this]() mutable { device_.set_device(); }, i);
+    copy_stream_.enqueue(device_.get_stream_from_pool(TIMEOUT_MS));
   }
 
   prepare_stream_ = device_.get_stream_from_pool();
@@ -152,18 +144,9 @@ bool WorkerImpl::allocate_host_kv_cache(
   host_kv_cache_shape[1][0] = num_layers;
 
   // create a KVCache shape: block_size * [layers, token, head, dim]
-  host_kv_caches_.reserve(host_bolck_size);
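+  // Carve all host KV tensors out of one contiguous, page-aligned, locked
+  // allocation instead of per-block pin_memory() tensors (see
+  // AlignedTensorCreater below).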
+  aligned_tensor_creater_ = std::make_unique<AlignedTensorCreater>(
+      host_kv_cache_shape, dtype_, host_bolck_size, &host_kv_caches_);
 
-  for (int64_t i = 0; i < host_bolck_size; ++i) {
-    torch::Tensor key_cache, value_cache;
-    key_cache = torch::empty(host_kv_cache_shape[0],
-                             torch::dtype(dtype_).device(torch::kCPU))
-                    .pin_memory();
-    value_cache = torch::empty(host_kv_cache_shape[1],
-                               torch::dtype(dtype_).device(torch::kCPU))
-                    .pin_memory();
-    host_kv_caches_.emplace_back(key_cache, value_cache);
-  }
   LOG(INFO) << "Initializing host kv block size: " << host_bolck_size;
 
   int32_t device_id = device_.index();
@@ -188,6 +171,8 @@ bool WorkerImpl::allocate_host_kv_cache(
   config.tp_rank = options_.dp_size() > 1
                        ? options_.node_rank() % options_.dp_size()
                        : options_.node_rank();
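+  // Describe the contiguous pool backing host_kv_caches_ so the store sees
+  // the whole region.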
+  config.total_size = aligned_tensor_creater_->get_total_size();
+  config.tensor_data = aligned_tensor_creater_->get_base_ptr();
 
   if (!KVCacheStore::get_instance().init(config, &host_kv_caches_)) {
     LOG(ERROR) << "Init KVCacheStore fail!";
@@ -805,9 +790,6 @@ uint32_t WorkerImpl::offload_kv_blocks(
 
 bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
 #if defined(USE_NPU)
-  CHECK(d2h_stream_.count(std::this_thread::get_id()) != 0)
-      << "WorkerImpl::d2h_batch_copy can only be called in d2h_threadpool_.";
-
   const int64_t num_layers = context_.get_model_args().n_layers();
   uint32_t num_batches = block_transfer_info.size() * num_layers * 2;
   void** srcs = new void*[num_batches];
@@ -840,8 +822,9 @@ bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
     }
   }
 
-  c10::StreamGuard streamGuard =
-      d2h_stream_[std::this_thread::get_id()]->set_stream_guard();
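+  // Borrow a stream from the shared pool; wait_dequeue() blocks until one
+  // becomes available.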
+  std::unique_ptr<Stream> stream;
+  copy_stream_.wait_dequeue(stream);
+  c10::StreamGuard streamGuard = stream->set_stream_guard();
 
   // TODO(kangmeng): change to async API
   aclError ret = aclrtMemcpyBatch(dsts,
@@ -856,14 +839,18 @@ bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   if (ret != 0 || fail_index != SIZE_MAX) {
     LOG(ERROR) << "aclrtMemcpyBatch error: " << ret
               << ", fail_index:" << fail_index;
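+    // Return the borrowed stream on every exit path so the pool is not
+    // drained by failures.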
+    copy_stream_.enqueue(std::move(stream));
     return false;
   }
 
-  if (d2h_stream_[std::this_thread::get_id()]->synchronize() != 0) {
+  if (stream->synchronize() != 0) {
     LOG(ERROR) << "d2h_batch_copy timeout!";
+    copy_stream_.enqueue(std::move(stream));
     return false;
   }
 
+  copy_stream_.enqueue(std::move(stream));
+
   delete[] dsts;
   delete[] srcs;
   delete[] copy_size;
@@ -875,8 +862,6 @@ bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
 bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
                                 Slice<BlockTransferInfo>& block_transfer_info) {
 #if defined(USE_NPU)
-  CHECK(h2d_stream_.count(std::this_thread::get_id()) != 0)
-      << "WorkerImpl::h2d_batch_copy can only be called in h2d_threadpool_.";
   CHECK(block_transfer_info.size() < BATCH_COPY_MAX_SIZE / 2)
       << "h2d_batch_copy support copy blocks less than "
       << BATCH_COPY_MAX_SIZE / 2 << ", but got " << block_transfer_info.size();
@@ -903,9 +888,10 @@ bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
   aclrtMemcpyBatchAttr attrs[1] = {h2d_attrs_};
   size_t attrs_indexes[1] = {0};
 
-  c10::StreamGuard streamGuard =
-      h2d_stream_[std::this_thread::get_id()]->set_stream_guard();
-  auto stream = h2d_stream_[std::this_thread::get_id()]->get_stream()->stream();
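+  // Borrow a pooled stream, mirroring d2h_batch_copy.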
+  std::unique_ptr<Stream> stream;
+  copy_stream_.wait_dequeue(stream);
+  c10::StreamGuard streamGuard = stream->set_stream_guard();
+
   aclError ret = 0;
 
   for (int layer_id = 0; layer_id < num_layers; layer_id++) {
@@ -946,7 +932,7 @@ bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
       LOG(ERROR) << "aclrtMemcpyBatch error: " << ret
                  << ", fail_index:" << fail_index;
     } else {
-      ret = aclrtRecordEvent(*event, stream);
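+      // The pooled Stream is a wrapper, so unwrap it to get the raw ACL
+      // stream the event is recorded on.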
+      ret = aclrtRecordEvent(*event, stream->get_stream()->stream());
       if (ret != 0) {
         LOG(ERROR) << "aclrtRecordEvent error: " << ret;
       }
@@ -955,10 +941,12 @@ bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
     if (ret != 0) break;
   }
 
-  if (h2d_stream_[std::this_thread::get_id()]->synchronize() != 0) {
+  if (stream->synchronize() != 0) {
     LOG(ERROR) << "h2d_batch_copy timeout!";
+    copy_stream_.enqueue(std::move(stream));
     return false;
   }
+  copy_stream_.enqueue(std::move(stream));
 
   delete[] dsts;
   delete[] srcs;
@@ -1026,4 +1014,68 @@ uint32_t WorkerImpl::prefetch_from_storage(
       .get();
 }
 
+AlignedTensorCreater::AlignedTensorCreater(
+    const std::vector<std::vector<int64_t>>& tensor_shapes,
+    const torch::ScalarType dtype,
+    const uint32_t num_tensors,
+    std::vector<xllm::KVCache>* tensors) {
+  CHECK(tensor_shapes.size() == 2)
+      << "tensor_shapes.size() must equal to 2, but got "
+      << tensor_shapes.size();
+
+  int64_t elements_per_k_tensor = 1;
+  int64_t elements_per_v_tensor = 1;
+
+  for (auto dim : tensor_shapes[0]) {
+    elements_per_k_tensor *= dim;
+  }
+  for (auto dim : tensor_shapes[1]) {
+    elements_per_v_tensor *= dim;
+  }
+
+  size_t element_size = torch::elementSize(dtype);
+  size_t bytes_per_k_tensor = elements_per_k_tensor * element_size;
+  size_t bytes_per_v_tensor = elements_per_v_tensor * element_size;
+  size_t page_size = sysconf(_SC_PAGESIZE);
+  total_size_ = num_tensors * (bytes_per_k_tensor + bytes_per_v_tensor);
+  total_size_ = ((total_size_ + page_size - 1) / page_size) * page_size;
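+  // Rounding up to a page multiple keeps the mapping page-aligned end to end.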
+
+  base_ptr_ = mmap(nullptr,
+                   total_size_,
+                   PROT_READ | PROT_WRITE,
+                   MAP_PRIVATE | MAP_ANONYMOUS,
+                   -1,
+                   0);
+
+  if (base_ptr_ == MAP_FAILED) {
+    LOG(FATAL) << "Failed to allocate aligned memory pool!";
+  }
+
+  if (mlock(base_ptr_, total_size_) != 0) {
+    munmap(base_ptr_, total_size_);
+    LOG(FATAL) << "Failed to lock memory pool!";
+  }
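+  // Locking the pool in RAM replaces the per-block pin_memory() calls that
+  // the old allocation path relied on.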
+
+  size_t current_offset = 0;
+  auto options = torch::TensorOptions().dtype(dtype).device(torch::kCPU);
+  tensors->reserve(num_tensors);
+
+  for (size_t i = 0; i < num_tensors; ++i) {
+    void* k_tensor_ptr = static_cast<char*>(base_ptr_) + current_offset;
+    torch::Tensor k_tensor =
+        torch::from_blob(k_tensor_ptr, tensor_shapes[0], options);
+    current_offset += bytes_per_k_tensor;
+
+    void* v_tensor_ptr = static_cast<char*>(base_ptr_) + current_offset;
+    torch::Tensor v_tensor =
+        torch::from_blob(v_tensor_ptr, tensor_shapes[1], options);
+    current_offset += bytes_per_v_tensor;
+
+    tensors->emplace_back(k_tensor, v_tensor);
+  }
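+  // from_blob() does not take ownership of base_ptr_, so this object must
+  // outlive the KV caches built on top of it.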
+
+  LOG(INFO) << "Page aligned: "
+            << ((uintptr_t)base_ptr_ % page_size == 0 ? "YES" : "NO");
+}
+
 }  // namespace xllm