@@ -152,18 +152,9 @@ bool WorkerImpl::allocate_host_kv_cache(
152152 host_kv_cache_shape[1 ][0 ] = num_layers;
153153
154154 // create a KVCache shape: block_size * [layers, token, head, dim]
155- host_kv_caches_.reserve (host_bolck_size);
155+ aligned_tensor_creater_ = std::make_unique<AlignedTensorCreater>(
156+ host_kv_cache_shape, dtype_, host_bolck_size, &host_kv_caches_);
156157
157- for (int64_t i = 0 ; i < host_bolck_size; ++i) {
158- torch::Tensor key_cache, value_cache;
159- key_cache = torch::empty (host_kv_cache_shape[0 ],
160- torch::dtype (dtype_).device (torch::kCPU ))
161- .pin_memory ();
162- value_cache = torch::empty (host_kv_cache_shape[1 ],
163- torch::dtype (dtype_).device (torch::kCPU ))
164- .pin_memory ();
165- host_kv_caches_.emplace_back (key_cache, value_cache);
166- }
167158 LOG (INFO) << " Initializing host kv block size: " << host_bolck_size;
168159
169160 int32_t device_id = device_.index ();
@@ -188,6 +179,8 @@ bool WorkerImpl::allocate_host_kv_cache(
188179 config.tp_rank = options_.dp_size () > 1
189180 ? options_.node_rank () % options_.dp_size ()
190181 : options_.node_rank ();
182+ config.total_size = aligned_tensor_creater_->get_total_size ();
183+ config.tensor_data = aligned_tensor_creater_->get_base_ptr ();
191184
192185 if (!KVCacheStore::get_instance ().init (config, &host_kv_caches_)) {
193186 LOG (ERROR) << " Init KVCacheStore fail!" ;
@@ -1026,4 +1019,68 @@ uint32_t WorkerImpl::prefetch_from_storage(
10261019 .get ();
10271020}
10281021
1022+ AlignedTensorCreater::AlignedTensorCreater (
1023+ const std::vector<std::vector<int64_t >>& tensor_shapes,
1024+ const torch::ScalarType dtype,
1025+ const uint32_t num_tensors,
1026+ std::vector<xllm::KVCache>* tensors) {
1027+ CHECK (tensor_shapes.size () == 2 )
1028+ << " tensor_shapes.size() must equal to 2, but got "
1029+ << tensor_shapes.size ();
1030+
1031+ int64_t elements_per_k_tensor = 1 ;
1032+ int64_t elements_per_v_tensor = 1 ;
1033+
1034+ for (auto dim : tensor_shapes[0 ]) {
1035+ elements_per_k_tensor *= dim;
1036+ }
1037+ for (auto dim : tensor_shapes[1 ]) {
1038+ elements_per_v_tensor *= dim;
1039+ }
1040+
1041+ size_t element_size = torch::elementSize (dtype);
1042+ size_t bytes_per_k_tensor = elements_per_k_tensor * element_size;
1043+ size_t bytes_per_v_tensor = elements_per_v_tensor * element_size;
1044+ size_t page_size = sysconf (_SC_PAGESIZE);
1045+ total_size_ = num_tensors * (bytes_per_k_tensor + bytes_per_v_tensor);
1046+ total_size_ = ((total_size_ + page_size - 1 ) / page_size) * page_size;
1047+
1048+ base_ptr_ = mmap (nullptr ,
1049+ total_size_,
1050+ PROT_READ | PROT_WRITE,
1051+ MAP_PRIVATE | MAP_ANONYMOUS,
1052+ -1 ,
1053+ 0 );
1054+
1055+ if (base_ptr_ == MAP_FAILED) {
1056+ LOG (FATAL) << " Failed to allocate aligned memory pool!" ;
1057+ }
1058+
1059+ if (mlock (base_ptr_, total_size_) != 0 ) {
1060+ munmap (base_ptr_, total_size_);
1061+ LOG (FATAL) << " Failed to lock memory pool!" ;
1062+ }
1063+
1064+ size_t current_offset = 0 ;
1065+ auto options = torch::TensorOptions ().dtype (dtype).device (torch::kCPU );
1066+ tensors->reserve (num_tensors);
1067+
1068+ for (size_t i = 0 ; i < num_tensors; ++i) {
1069+ void * k_tensor_ptr = static_cast <char *>(base_ptr_) + current_offset;
1070+ torch::Tensor k_tensor =
1071+ torch::from_blob (k_tensor_ptr, tensor_shapes[0 ], options);
1072+ current_offset += bytes_per_k_tensor;
1073+
1074+ void * v_tensor_ptr = static_cast <char *>(base_ptr_) + current_offset;
1075+ torch::Tensor v_tensor =
1076+ torch::from_blob (v_tensor_ptr, tensor_shapes[1 ], options);
1077+ current_offset += bytes_per_v_tensor;
1078+
1079+ tensors->emplace_back (k_tensor, v_tensor);
1080+ }
1081+
1082+ LOG (INFO) << " Page aligned: "
1083+ << ((uintptr_t )base_ptr_ % page_size == 0 ? " YES" : " NO" );
1084+ }
1085+
10291086} // namespace xllm
0 commit comments