-
Notifications
You must be signed in to change notification settings - Fork 77
feat: optimize kv cache load/offload. #306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
17797ce to
034f86e
Compare
| return torch::Tensor(); | ||
| } | ||
| } | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a TODO tag, MTP need more support.
7cd5bd4 to
d4446aa
Compare
RobbieLeung
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
| public: | ||
| ~ServerStreamHandler() { | ||
| if (!promise_set_.exchange(true)) { | ||
| try { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why use try catch here?
| std::unique_ptr<std::thread> polling_thread_; | ||
|
|
||
| std::unique_ptr<ThreadPool> threadpool_; | ||
| ThreadPool copy_threadpool_{5}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why 5 threads? ??
64f071c to
eee6a80
Compare
eee6a80 to
90457da
Compare
90457da to
b996af0
Compare
| "", | ||
| "The address of the kv cache store metadata service."); | ||
|
|
||
| DEFINE_string(store_local_hostname, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the different between store_metadata_server and store_local_hostname.
| pb_cache.set_dst_block_id(info.dst_block_id); | ||
| pb_cache.set_hash_key(info.hash_key, MURMUR_HASH3_VALUE_LEN); | ||
|
|
||
| *pb_block_transfer_info->mutable_transfer_infos()->Add() = pb_cache; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
*pb_block_transfer_info->mutable_transfer_infos()->Add() = std::move(pb_cache);
| uint8_t* hash_key = nullptr; | ||
|
|
||
| CacheBlockInfo() {} | ||
| enum class TransferType : uint8_t { G2H = 0, H2D = 1, D2G = 2 }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: could we add some comments for these types :)
|
|
||
| if (success_cnt != current_slice.size() || | ||
| i + stream_copy_batch_size_ >= transfer_slice.size()) { | ||
| is_completed = true; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
emmm... Does the code here indicate a prefetch failure?
| } | ||
| } | ||
| if (is_completed) { | ||
| close_future.wait(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: If is_completed was set to false above, does that mean we no longer need to wait() on close_future here?
how brpc to handle stream_handler in this case
And by the way, how can we ensure that multiple batches are delivered in order or received in order?
|
|
||
| size_t PrefixCache::insert(const std::vector<Block>& blocks) { | ||
| std::vector<Murmur3Key> insert_keys; | ||
| return insert(blocks, &insert_keys); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the purpose of insert_keys, it seems not be used later
| int Stream::synchronize() const { | ||
| #if defined(USE_NPU) | ||
| return aclrtSynchronizeStream(stream_.stream()); | ||
| return aclrtSynchronizeStreamWithTimeout(stream_.stream(), timeout_); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in which case we need timeout? and what happen if timeout.
| threadpool_.schedule([this]() mutable { device_.set_device(); }); | ||
| general_threadpool_.schedule([this]() mutable { device_.set_device(); }); | ||
| for (int i = 0; i < h2d_threadpool_.size(); i++) { | ||
| h2d_threadpool_.schedule_with_tid( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
threadpool's construction function can pass a init_func now, so we can build h2d_threadpool_ like this
h2d_threadpool_ = std::make_unique<ThreadPool>(
2, [this]() mutable { device_.set_device(); });
| } | ||
|
|
||
| uint32_t WorkerImpl::offload_kv_blocks( | ||
| const std::vector<BlockTransferInfo>& block_transfer_info) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Perhaps it would be best to abstract this code(and the code below) into a new class here.
| std::move(copy_out_blocks_async(input.input_params))); | ||
| { | ||
| std::lock_guard<std::mutex> lock(mutex_); | ||
| if (layer_wise_load_synchronizer_.count(input.input_params.batch_id) != |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can we dont use lock here ? just a suggestion
No description provided.