diff --git a/fbgemm_gpu/cmake/TbeInference.cmake b/fbgemm_gpu/cmake/TbeInference.cmake
index 2c3fe235f3..4b75c5149b 100644
--- a/fbgemm_gpu/cmake/TbeInference.cmake
+++ b/fbgemm_gpu/cmake/TbeInference.cmake
@@ -22,6 +22,7 @@ gpu_cpp_library(
     ${FBGEMM_GPU}/src/split_embeddings_cache/lru_cache_populate_byte.cpp
     ${FBGEMM_GPU}/src/split_embeddings_cache/lxu_cache.cpp
     ${FBGEMM_GPU}/src/split_embeddings_cache/split_embeddings_cache_ops.cpp
+    ${FBGEMM_GPU}/src/split_embeddings_cache/raw_embedding_streamer.cpp
   GPU_SRCS
     ${FBGEMM_GPU}/src/split_embeddings_cache/lfu_cache_find.cu
     ${FBGEMM_GPU}/src/split_embeddings_cache/lfu_cache_populate.cu
diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h
new file mode 100644
index 0000000000..e44dc65025
--- /dev/null
+++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h
@@ -0,0 +1,117 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include <torch/custom_class.h>
+#ifdef FBGEMM_FBCODE
+#include <folly/concurrency/UnboundedQueue.h>
+#include <folly/coro/Task.h>
+#endif
+
+#include <atomic>
+#include <thread>
+
+namespace fbgemm_gpu {
+
+struct StreamQueueItem {
+  at::Tensor indices;
+  at::Tensor weights;
+  at::Tensor count;
+  StreamQueueItem(
+      at::Tensor src_indices,
+      at::Tensor src_weights,
+      at::Tensor src_count)
+      : indices(std::move(src_indices)),
+        weights(std::move(src_weights)),
+        count(std::move(src_count)) {}
+};
+
+class RawEmbeddingStreamer : public torch::jit::CustomClassHolder {
+ public:
+  explicit RawEmbeddingStreamer(
+      std::string unique_id,
+      bool enable_raw_embedding_streaming,
+      int64_t res_store_shards,
+      int64_t res_server_port,
+      std::vector<std::string> table_names,
+      std::vector<int64_t> table_offsets,
+      const std::vector<int64_t>& table_sizes);
+
+  virtual ~RawEmbeddingStreamer();
+
+  /// Stream out the non-negative elements in <indices> and their paired
+  /// embeddings from <weights> for the first <count> elements in the tensor.
+  /// It spins up a thread that copies all 3 tensors to CPU and injects them
+  /// into the background queue, which is picked up by another set of thread
+  /// pools for streaming out to the thrift server (co-located on the same
+  /// host for now).
+  ///
+  /// This is used in a CUDA stream callback, which doesn't need to be
+  /// serialized with other callbacks; a separate thread is used to maximize
+  /// the overlap with other callbacks.
+  ///
+  /// @param indices The 1D embedding index tensor; negative values are
+  /// skipped
+  /// @param weights The 2D tensor in which each row (embedding) is paired
+  /// with the corresponding element in <indices>
+  /// @param count A single-element tensor that contains the number of
+  /// indices to be processed
+  /// @param require_tensor_copy whether the tensors must be copied before
+  /// being enqueued
+  /// @param blocking_tensor_copy whether to copy the tensors to be streamed
+  /// in a blocking manner
+  ///
+  /// @return None
+  void stream(
+      const at::Tensor& indices,
+      const at::Tensor& weights,
+      const at::Tensor& count,
+      bool require_tensor_copy,
+      bool blocking_tensor_copy = true);
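+
+  /// Illustrative usage (a sketch mirroring raw_embedding_streamer_test.cpp;
+  /// the identifiers and values below are hypothetical, not a checked-in
+  /// example):
+  ///
+  ///   auto streamer = std::make_unique<RawEmbeddingStreamer>(
+  ///       "tbe_0",
+  ///       /*enable_raw_embedding_streaming=*/true,
+  ///       /*res_store_shards=*/3,
+  ///       /*res_server_port=*/0,
+  ///       std::vector<std::string>{"tb1", "tb2"},
+  ///       std::vector<int64_t>{0, 100},
+  ///       std::vector<int64_t>{0, 50, 200});
+  ///   // Copies the first <count> rows on the calling thread, then hands
+  ///   // them to the background streaming thread.
+  ///   streamer->stream(indices, weights, count,
+  ///                    /*require_tensor_copy=*/true,
+  ///                    /*blocking_tensor_copy=*/true);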
+
+#ifdef FBGEMM_FBCODE
+  folly::coro::Task<void> tensor_stream(
+      const at::Tensor& indices,
+      const at::Tensor& weights);
+  /*
+   * Copy the indices, weights and count tensors and enqueue them for
+   * asynchronous streaming.
+   */
+  void copy_and_enqueue_stream_tensors(
+      const at::Tensor& indices,
+      const at::Tensor& weights,
+      const at::Tensor& count);
+
+  /*
+   * Join the stream tensor copy thread, making sure the thread is properly
+   * finished before creating a new one.
+   */
+  void join_stream_tensor_copy_thread();
+
+  /*
+   * FOR TESTING: Join the weight stream thread, making sure the thread is
+   * properly finished for destruction and testing.
+   */
+  void join_weights_stream_thread();
+  // FOR TESTING: get queue size.
+  uint64_t get_weights_to_stream_queue_size();
+#endif
+
+ private:
+  std::atomic<bool> stop_{false};
+  std::string unique_id_;
+  bool enable_raw_embedding_streaming_;
+  int64_t res_store_shards_;
+  int64_t res_server_port_;
+  std::vector<std::string> table_names_;
+  std::vector<int64_t> table_offsets_;
+  at::Tensor table_sizes_;
+#ifdef FBGEMM_FBCODE
+  std::unique_ptr<std::thread> weights_stream_thread_;
+  folly::UMPSCQueue<StreamQueueItem, true> weights_to_stream_queue_;
+  std::unique_ptr<std::thread> stream_tensor_copy_thread_;
+#endif
+};
+
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp b/fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp
new file mode 100644
index 0000000000..ffee3ff672
--- /dev/null
+++ b/fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp
@@ -0,0 +1,314 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#ifdef FBGEMM_FBCODE
+#include <folly/coro/BlockingWait.h>
+#include <folly/logging/xlog.h>
+#include <folly/stop_watch.h>
+#include "aiplatform/gmpp/experimental/training_ps/gen-cpp2/TrainingParameterServerService.h"
+#include "caffe2/torch/fb/distributed/wireSerializer/WireSerializer.h"
+#include "servicerouter/client/cpp2/ClientParams.h"
+#include "servicerouter/client/cpp2/ServiceRouter.h"
+#include "torch/csrc/autograd/record_function_ops.h"
+#include "torch/types.h"
+#endif
+#include "fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h"
+#include "fbgemm_gpu/utils/dispatch_macros.h"
+
+namespace fbgemm_gpu {
+namespace {
+
+#ifdef FBGEMM_FBCODE
+/*
+ * Get the thrift client to the training parameter server service.
+ * There is a destruction double-free issue when wrapping the member
+ * variable under ifdef, and creating a client is relatively cheap, so this
+ * helper function gets the client just before sending requests.
+ */
+std::unique_ptr<apache::thrift::Client<
+    aiplatform::gmpp::experimental::training_ps::
+        TrainingParameterServerService>>
+get_res_client(int64_t res_server_port) {
+  auto& factory = facebook::servicerouter::cpp2::getClientFactory();
+  auto& params = facebook::servicerouter::ClientParams().setSingleHost(
+      "::", res_server_port);
+  return factory.getSRClientUnique<apache::thrift::Client<
+      aiplatform::gmpp::experimental::training_ps::
+          TrainingParameterServerService>>(
+      "realtime.delta.publish.esr", params);
+}
+#endif
+
+/// Read a scalar value from a tensor that may be a UVM tensor.
+/// Note that `tensor.item()` is not allowed on a UVM tensor in
+/// PyTorch.
+inline int64_t get_maybe_uvm_scalar(const at::Tensor& tensor) {
+  return tensor.scalar_type() == at::ScalarType::Long
+      ? *(tensor.data_ptr<int64_t>())
+      : *(tensor.data_ptr<int32_t>());
+}
+
+fbgemm_gpu::StreamQueueItem tensor_copy(
+    const at::Tensor& indices,
+    const at::Tensor& weights,
+    const at::Tensor& count) {
+  auto num_sets = get_maybe_uvm_scalar(count);
+  auto new_indices = at::empty(
+      num_sets, at::TensorOptions().device(at::kCPU).dtype(indices.dtype()));
+  auto new_weights = at::empty(
+      {num_sets, weights.size(1)},
+      at::TensorOptions().device(at::kCPU).dtype(weights.dtype()));
+  auto new_count =
+      at::empty({1}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+  FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
+      weights.scalar_type(), "tensor_copy", [&] {
+        using value_t = scalar_t;
+        FBGEMM_DISPATCH_INTEGRAL_TYPES(
+            indices.scalar_type(), "tensor_copy", [&] {
+              using index_t = scalar_t;
+              auto indices_addr = indices.data_ptr<index_t>();
+              auto new_indices_addr = new_indices.data_ptr<index_t>();
+              std::copy(
+                  indices_addr,
+                  indices_addr + num_sets,
+                  new_indices_addr); // dst_start
+
+              auto weights_addr = weights.data_ptr<value_t>();
+              auto new_weights_addr = new_weights.data_ptr<value_t>();
+              std::copy(
+                  weights_addr,
+                  weights_addr + num_sets * weights.size(1),
+                  new_weights_addr); // dst_start
+            });
+      });
+  *new_count.data_ptr<int64_t>() = num_sets;
+  return fbgemm_gpu::StreamQueueItem{new_indices, new_weights, new_count};
+}
+
+} // namespace
+
+RawEmbeddingStreamer::RawEmbeddingStreamer(
+    std::string unique_id,
+    bool enable_raw_embedding_streaming,
+    int64_t res_store_shards,
+    int64_t res_server_port,
+    std::vector<std::string> table_names,
+    std::vector<int64_t> table_offsets,
+    const std::vector<int64_t>& table_sizes)
+    : unique_id_(std::move(unique_id)),
+      enable_raw_embedding_streaming_(enable_raw_embedding_streaming),
+      res_store_shards_(res_store_shards),
+      res_server_port_(res_server_port),
+      table_names_(std::move(table_names)),
+      table_offsets_(std::move(table_offsets)),
+      table_sizes_(at::tensor(table_sizes)) {
+#ifdef FBGEMM_FBCODE
+  if (enable_raw_embedding_streaming_) {
+    XLOG(INFO) << "[TBE_ID" << unique_id_
+               << "] Raw embedding streaming enabled with res_server_port at "
+               << res_server_port;
+    // The first call to get the client is expensive, so eagerly get it here
+    auto _eager_client = get_res_client(res_server_port_);
+
+    weights_stream_thread_ = std::make_unique<std::thread>([=, this] {
+      while (!stop_) {
+        auto stream_item_ptr = weights_to_stream_queue_.try_peek();
+        if (!stream_item_ptr) {
+          std::this_thread::sleep_for(std::chrono::milliseconds(10));
+          continue;
+        }
+        if (stop_) {
+          return;
+        }
+        auto& indices = stream_item_ptr->indices;
+        auto& weights = stream_item_ptr->weights;
+        folly::stop_watch<std::chrono::milliseconds> stop_watch;
+        folly::coro::blockingWait(tensor_stream(indices, weights));
+
+        weights_to_stream_queue_.dequeue();
+        XLOG_EVERY_MS(INFO, 60000)
+            << "[TBE_ID" << unique_id_
+            << "] end stream queue size: " << weights_to_stream_queue_.size()
+            << " stream takes " << stop_watch.elapsed().count() << "ms";
+      }
+    });
+  }
+#endif
+}
+
+RawEmbeddingStreamer::~RawEmbeddingStreamer() {
+  stop_ = true;
+#ifdef FBGEMM_FBCODE
+  if (enable_raw_embedding_streaming_) {
+    join_stream_tensor_copy_thread();
+    join_weights_stream_thread();
+  }
+#endif
+}
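+
+// Illustrative example of the copy semantics above (hypothetical values):
+// with indices = [12, 7, 3, -1] on a UVM/CPU tensor, weights of shape
+// [4, D], and count = [2], tensor_copy() returns CPU tensors holding
+// indices [12, 7], the first two weight rows, and count = 2. Negative
+// indices are not filtered here; tensor_stream() drops them later.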
+
+void RawEmbeddingStreamer::stream(
+    const at::Tensor& indices,
+    const at::Tensor& weights,
+    const at::Tensor& count,
+    bool require_tensor_copy,
+    bool blocking_tensor_copy) {
+  if (!enable_raw_embedding_streaming_) {
+    return;
+  }
+#ifdef FBGEMM_FBCODE
+  auto rec = torch::autograd::profiler::record_function_enter_new(
+      "## RawEmbeddingStreamer::stream_callback ##");
+  if (!require_tensor_copy) {
+    StreamQueueItem stream_item(indices, weights, count);
+    weights_to_stream_queue_.enqueue(stream_item);
+    rec->record.end();
+    return;
+  }
+  if (blocking_tensor_copy) {
+    copy_and_enqueue_stream_tensors(indices, weights, count);
+    rec->record.end();
+    return;
+  }
+  // Make sure the previous thread is done before starting a new one
+  join_stream_tensor_copy_thread();
+  // CUDA dispatches the host callbacks all in the same CPU thread, but the
+  // callbacks don't need to be serialized. So we spin up a new thread to
+  // unblock the CUDA stream, letting CUDA continue executing other host
+  // callbacks, e.g. get/evict.
+  stream_tensor_copy_thread_ = std::make_unique<std::thread>([=, this]() {
+    copy_and_enqueue_stream_tensors(indices, weights, count);
+  });
+  rec->record.end();
+#endif
+}
+
+#ifdef FBGEMM_FBCODE
+folly::coro::Task<void> RawEmbeddingStreamer::tensor_stream(
+    const at::Tensor& indices,
+    const at::Tensor& weights) {
+  using namespace ::aiplatform::gmpp::experimental::training_ps;
+  if (indices.size(0) != weights.size(0)) {
+    XLOG(ERR) << "[TBE_ID" << unique_id_
+              << "] Indices and weights size mismatched " << indices.size(0)
+              << " " << weights.size(0);
+    co_return;
+  }
+  folly::stop_watch<std::chrono::milliseconds> stop_watch;
+  XLOG_EVERY_MS(INFO, 60000)
+      << "[TBE_ID" << unique_id_
+      << "] send streaming request: indices = " << indices.size(0)
+      << ", weights = " << weights.size(0);
+
+  auto biggest_idx = table_sizes_.index({table_sizes_.size(0) - 1});
+  auto mask =
+      at::logical_and(indices >= 0, indices < biggest_idx).nonzero().squeeze();
+  auto filtered_indices = indices.index_select(0, mask);
+  auto filtered_weights = weights.index_select(0, mask);
+  auto num_invalid_indices = indices.size(0) - filtered_indices.size(0);
+  if (num_invalid_indices > 0) {
+    XLOG(INFO) << "[TBE_ID" << unique_id_
+               << "] number of invalid indices: " << num_invalid_indices;
+  }
+  // 1. Transform local row indices to embedding table global row indices
+  at::Tensor table_indices =
+      (at::searchsorted(table_sizes_, filtered_indices, false, true) - 1)
+          .to(torch::kInt8);
+  auto tb_ac = table_indices.accessor<int8_t, 1>();
+  auto indices_ac = filtered_indices.accessor<int64_t, 1>();
+  auto tb_sizes_ac = table_sizes_.accessor<int64_t, 1>();
+  std::vector<int64_t> global_indices(tb_ac.size(0), 0);
+  std::vector<int64_t> shard_indices(tb_ac.size(0), 0);
+
+  for (int i = 0; i < tb_ac.size(0); ++i) {
+    auto tb_idx = tb_ac[i];
+    global_indices[i] =
+        indices_ac[i] - tb_sizes_ac[tb_idx] + table_offsets_[tb_idx];
+    // Hash to shard.
+    // If we do row-range sharding, also shard here.
+    auto fqn = table_names_[tb_idx];
+    auto hash_key = folly::to<std::string>(fqn, global_indices[i]);
+    auto shard_id =
+        furcHash(hash_key.data(), hash_key.size(), res_store_shards_);
+    shard_indices[i] = shard_id;
+  }
+  auto global_indices_tensor = at::tensor(global_indices);
+  auto shard_indices_tensor = at::tensor(shard_indices);
+  auto total_rows = global_indices_tensor.size(0);
+  XLOG_EVERY_MS(INFO, 60000)
+      << "[TBE_ID" << unique_id_ << "] hash and globalize rows " << total_rows
+      << " in: " << stop_watch.elapsed().count() << "ms";
+  stop_watch.reset();
+
+  auto res_client = get_res_client(res_server_port_);
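+  // Illustrative example of the hashing above (hypothetical values): with
+  // res_store_shards_ = 3, a row of table "tb2" with global index 123 gets
+  // shard furcHash(key.data(), key.size(), 3) for key = "tb2123", i.e. one
+  // of {0, 1, 2}. The same (fqn, global index) pair always maps to the same
+  // shard, which is what the .Times(3) expectations in the unit tests rely
+  // on.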
+  // 2. Split by shards
+  for (int i = 0; i < res_store_shards_; ++i) {
+    auto shard_mask = shard_indices_tensor.eq(i).nonzero().squeeze();
+    auto table_indices_masked = table_indices.index_select(0, shard_mask);
+    auto rows_in_shard = table_indices_masked.numel();
+    if (rows_in_shard == 0) {
+      continue;
+    }
+    auto global_indices_masked =
+        global_indices_tensor.index_select(0, shard_mask);
+    auto weights_masked = filtered_weights.index_select(0, shard_mask);
+
+    if (weights_masked.size(0) != rows_in_shard ||
+        global_indices_masked.numel() != rows_in_shard) {
+      XLOG(ERR)
+          << "[TBE_ID" << unique_id_
+          << "] don't send the request for size-mismatched tensors, table: "
+          << rows_in_shard << " weights: " << weights_masked.size(0)
+          << " global_indices: " << global_indices_masked.numel();
+      continue;
+    }
+    SetEmbeddingsRequest req;
+    req.shardId() = i;
+    req.fqns() = table_names_;
+
+    req.tableIndices() =
+        torch::distributed::wireDumpTensor(table_indices_masked);
+    req.rowIndices() =
+        torch::distributed::wireDumpTensor(global_indices_masked);
+    req.weights() = torch::distributed::wireDumpTensor(weights_masked);
+    co_await res_client->co_setEmbeddings(req);
+  }
+  co_return;
+}
+
+void RawEmbeddingStreamer::copy_and_enqueue_stream_tensors(
+    const at::Tensor& indices,
+    const at::Tensor& weights,
+    const at::Tensor& count) {
+  auto rec = torch::autograd::profiler::record_function_enter_new(
+      "## RawEmbeddingStreamer::copy_and_enqueue_stream_tensors ##");
+  auto stream_item = tensor_copy(indices, weights, count);
+  weights_to_stream_queue_.enqueue(stream_item);
+  rec->record.end();
+}
+
+void RawEmbeddingStreamer::join_stream_tensor_copy_thread() {
+  auto rec = torch::autograd::profiler::record_function_enter_new(
+      "## RawEmbeddingStreamer::join_stream_tensor_copy_thread ##");
+  if (stream_tensor_copy_thread_ != nullptr &&
+      stream_tensor_copy_thread_->joinable()) {
+    stream_tensor_copy_thread_->join();
+  }
+  rec->record.end();
+}
+
+void RawEmbeddingStreamer::join_weights_stream_thread() {
+  if (weights_stream_thread_ != nullptr && weights_stream_thread_->joinable()) {
+    stop_ = true;
+    weights_stream_thread_->join();
+  }
+}
+
+uint64_t RawEmbeddingStreamer::get_weights_to_stream_queue_size() {
+  return weights_to_stream_queue_.size();
+}
+#endif
+
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp
index 83d86fd097..89b1e9d720 100644
--- a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp
+++ b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp
@@ -7,6 +7,7 @@
  */
 
 #include "common.h"
+#include "fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h"
 
 namespace {
 
@@ -73,4 +74,39 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   DISPATCH_TO_META("lxu_cache_lookup", lxu_cache_lookup_meta);
 }
 
+static auto raw_embedding_streamer =
+    torch::class_<fbgemm_gpu::RawEmbeddingStreamer>(
+        "fbgemm",
+        "RawEmbeddingStreamer")
+        .def(
+            torch::init<
+                std::string,
+                bool,
+                int64_t,
+                int64_t,
+                std::vector<std::string>,
+                std::vector<int64_t>,
+                std::vector<int64_t>>(),
+            "",
+            {
+                torch::arg("unique_id") = "",
+                torch::arg("enable_raw_embedding_streaming") = false,
+                torch::arg("res_store_shards") = 0,
+                torch::arg("res_server_port") = 0,
+                torch::arg("table_names") = torch::List<std::string>(),
+                torch::arg("table_offsets") = torch::List<int64_t>(),
+                torch::arg("table_sizes") = torch::List<int64_t>(),
+            })
+        .def(
+            "stream",
+            &fbgemm_gpu::RawEmbeddingStreamer::stream,
+            "",
+            {
+                torch::arg("indices"),
+                torch::arg("weights"),
+                torch::arg("count"),
+                torch::arg("require_tensor_copy"),
+                torch::arg("blocking_tensor_copy"),
+            });
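+
+// Illustrative usage only (not part of this diff): once the fragment above
+// is registered, the class is reachable from Python through the TorchScript
+// custom-class bindings, e.g.:
+//
+//   streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
+//       "tbe_0", False, 0, 0, [], [], [])
+//   streamer.stream(indices, weights, count, True, True)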
torch::arg("count"), + torch::arg("require_tensor_copy"), + torch::arg("blocking_tensor_copy"), + }); + } // namespace diff --git a/fbgemm_gpu/src/split_embeddings_cache/tests/raw_embedding_streamer_test.cpp b/fbgemm_gpu/src/split_embeddings_cache/tests/raw_embedding_streamer_test.cpp new file mode 100644 index 0000000000..3b408f7a93 --- /dev/null +++ b/fbgemm_gpu/src/split_embeddings_cache/tests/raw_embedding_streamer_test.cpp @@ -0,0 +1,265 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include "deeplearning/fbgemm/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h" // @manual=//deeplearning/fbgemm/fbgemm_gpu/src/split_embeddings_cache:raw_embedding_streamer +#ifdef FBGEMM_FBCODE +#include +#include "aiplatform/gmpp/experimental/training_ps/gen-cpp2/TrainingParameterServerService.h" +#include "servicerouter/client/cpp2/mocks/MockSRClientFactory.h" +#include "thrift/lib/cpp2/util/ScopedServerInterfaceThread.h" +#endif + +using namespace ::testing; +using namespace fbgemm_gpu; +constexpr int64_t EMBEDDING_DIMENSION = 8; + +#ifdef FBGEMM_FBCODE +class MockTrainingParameterServerService + : public ::apache::thrift::ServiceHandler< + aiplatform::gmpp::experimental::training_ps:: + TrainingParameterServerService> { + public: + MOCK_METHOD( + folly::coro::Task>, + co_setEmbeddings, + (std::unique_ptr< + aiplatform::gmpp::experimental::training_ps::SetEmbeddingsRequest>)); +}; +#endif + +static std::unique_ptr +getRawEmbeddingStreamer( + const std::string& unique_id, + bool enable_raw_embedding_streaming = false, + const std::vector& table_names = {}, + const std::vector& table_offsets = {}, + const std::vector& table_sizes = {}) { + return std::make_unique( + unique_id, + enable_raw_embedding_streaming, + 3, // res_store_shards + 0, // res_server_port + table_names, + table_offsets, + table_sizes); +} + +TEST(RawEmbeddingStreamerTest, TestConstructorAndDestructor) { + std::vector table_names = {"tb1", "tb2", "tb3"}; + std::vector table_offsets = {0, 100, 300}; + std::vector table_sizes = {0, 50, 200, 300}; + + auto streamer = getRawEmbeddingStreamer( + "test_constructor", false, table_names, table_offsets, table_sizes); + EXPECT_NE(streamer, nullptr); +} + +TEST(RawEmbeddingStreamerTest, TestStreamWithoutStreaming) { + std::vector table_names = {"tb1", "tb2", "tb3"}; + std::vector table_offsets = {0, 100, 300}; + std::vector table_sizes = {0, 50, 200, 300}; + + auto streamer = getRawEmbeddingStreamer( + "test_no_streaming", false, table_names, table_offsets, table_sizes); + + auto indices = at::tensor( + {10, 2, 1, 150, 170, 230, 280}, + at::TensorOptions().device(at::kCPU).dtype(at::kLong)); + auto weights = at::randn( + {indices.size(0), EMBEDDING_DIMENSION}, + at::TensorOptions().device(at::kCPU).dtype(c10::kFloat)); + auto count = at::tensor( + {indices.size(0)}, at::TensorOptions().device(at::kCPU).dtype(at::kLong)); + + // Should not crash when streaming is disabled + streamer->stream(indices, weights, count, true, true); +} + +#ifdef FBGEMM_FBCODE +TEST(RawEmbeddingStreamerTest, TestTensorStream) { + std::vector table_names = {"tb1", "tb2", "tb3"}; + std::vector table_offsets = {0, 100, 300}; + std::vector table_sizes = {0, 50, 200, 300}; + + auto streamer = getRawEmbeddingStreamer( + "test_tensor_stream", true, table_names, table_offsets, table_sizes); + + 
+  // Mock TrainingParameterServerService
+  auto mock_service = std::make_shared<MockTrainingParameterServerService>();
+  auto mock_server =
+      std::make_shared<apache::thrift::ScopedServerInterfaceThread>(
+          mock_service,
+          "::1",
+          0,
+          facebook::services::TLSConfig::applyDefaultsToThriftServer);
+  auto& mock_client_factory =
+      facebook::servicerouter::getMockSRClientFactory(false /* strict */);
+  mock_client_factory.registerMockService(
+      "realtime.delta.publish.esr", mock_server);
+
+  // Test with invalid indices - should not call service
+  auto invalid_indices = at::tensor(
+      {300, 301, 999}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+  auto weights = at::randn(
+      {invalid_indices.size(0), EMBEDDING_DIMENSION},
+      at::TensorOptions().device(at::kCPU).dtype(c10::kFloat));
+  EXPECT_CALL(*mock_service, co_setEmbeddings(_)).Times(0);
+  folly::coro::blockingWait(streamer->tensor_stream(invalid_indices, weights));
+
+  // Test with valid indices - should call service
+  auto valid_indices = at::tensor(
+      {10, 2, 1, 150, 170, 230, 280},
+      at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+  weights = at::randn(
+      {valid_indices.size(0), EMBEDDING_DIMENSION},
+      at::TensorOptions().device(at::kCPU).dtype(c10::kFloat));
+  EXPECT_CALL(*mock_service, co_setEmbeddings(_))
+      .Times(3) // 3 shards with consistent hashing
+      .WillRepeatedly(folly::coro::gmock_helpers::CoInvoke(
+          [](std::unique_ptr<aiplatform::gmpp::experimental::training_ps::
+                 SetEmbeddingsRequest> request)
+              -> folly::coro::Task<std::unique_ptr<
+                  aiplatform::gmpp::experimental::training_ps::
+                      SetEmbeddingsResponse>> {
+            co_return std::make_unique<
+                aiplatform::gmpp::experimental::training_ps::
+                    SetEmbeddingsResponse>();
+          }));
+  folly::coro::blockingWait(streamer->tensor_stream(valid_indices, weights));
+}
+
+TEST(RawEmbeddingStreamerTest, TestStreamWithCopy) {
+  std::vector<std::string> table_names = {"tb1", "tb2", "tb3"};
+  std::vector<int64_t> table_offsets = {0, 100, 300};
+  std::vector<int64_t> table_sizes = {0, 50, 200, 300};
+
+  auto streamer = getRawEmbeddingStreamer(
+      "test_stream_copy", true, table_names, table_offsets, table_sizes);
+
+  // Mock TrainingParameterServerService
+  auto mock_service = std::make_shared<MockTrainingParameterServerService>();
+  auto mock_server =
+      std::make_shared<apache::thrift::ScopedServerInterfaceThread>(
+          mock_service,
+          "::1",
+          0,
+          facebook::services::TLSConfig::applyDefaultsToThriftServer);
+  auto& mock_client_factory =
+      facebook::servicerouter::getMockSRClientFactory(false /* strict */);
+  mock_client_factory.registerMockService(
+      "realtime.delta.publish.esr", mock_server);
+
+  auto indices = at::tensor(
+      {10, 2, 1, 150, 170, 230, 280},
+      at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+  auto weights = at::randn(
+      {indices.size(0), EMBEDDING_DIMENSION},
+      at::TensorOptions().device(at::kCPU).dtype(c10::kFloat));
+  auto count = at::tensor(
+      {indices.size(0)}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+
+  // Stop the dequeue thread to get accurate queue size
+  streamer->join_weights_stream_thread();
+
+  // Test blocking tensor copy
+  streamer->stream(indices, weights, count, true, true);
+  EXPECT_EQ(streamer->get_weights_to_stream_queue_size(), 1);
+
+  // Test non-blocking tensor copy
+  streamer->stream(indices, weights, count, true, false);
+  EXPECT_EQ(streamer->get_weights_to_stream_queue_size(), 1);
+  streamer->join_stream_tensor_copy_thread();
+  EXPECT_EQ(streamer->get_weights_to_stream_queue_size(), 2);
+}
+
+TEST(RawEmbeddingStreamerTest, TestStreamE2E) {
+  std::vector<std::string> table_names = {"tb1", "tb2", "tb3"};
+  std::vector<int64_t> table_offsets = {0, 100, 300};
+  std::vector<int64_t> table_sizes = {0, 50, 200, 300};
+
+  // Mock TrainingParameterServerService
+  auto mock_service = std::make_shared<MockTrainingParameterServerService>();
+  auto mock_server =
+      std::make_shared<apache::thrift::ScopedServerInterfaceThread>(
+          mock_service,
+          "::1",
+          0,
+          facebook::services::TLSConfig::applyDefaultsToThriftServer);
+  auto& mock_client_factory =
+      facebook::servicerouter::getMockSRClientFactory(false /* strict */);
+  mock_client_factory.registerMockService(
+      "realtime.delta.publish.esr", mock_server);
+
+  auto default_response =
+      [](std::unique_ptr<aiplatform::gmpp::experimental::training_ps::
+             SetEmbeddingsRequest> request)
+      -> folly::coro::Task<std::unique_ptr<
+          aiplatform::gmpp::experimental::training_ps::SetEmbeddingsResponse>> {
+    co_return std::make_unique<
+        aiplatform::gmpp::experimental::training_ps::SetEmbeddingsResponse>();
+  };
+
+  EXPECT_CALL(*mock_service, co_setEmbeddings(_))
+      .Times(3) // 3 shards with consistent hashing
+      .WillRepeatedly(folly::coro::gmock_helpers::CoInvoke(default_response));
+
+  auto streamer = getRawEmbeddingStreamer(
+      "test_stream_e2e", true, table_names, table_offsets, table_sizes);
+
+  auto indices = at::tensor(
+      {10, 2, 1, 150, 170, 230, 280},
+      at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+  auto weights = at::randn(
+      {indices.size(0), EMBEDDING_DIMENSION},
+      at::TensorOptions().device(at::kCPU).dtype(c10::kFloat));
+  auto count = at::tensor(
+      {indices.size(0)}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+
+  streamer->stream(indices, weights, count, true, true);
+  // Make sure dequeue finished
+  std::this_thread::sleep_for(std::chrono::seconds(1));
+  streamer->join_weights_stream_thread();
+}
+
+TEST(RawEmbeddingStreamerTest, TestMismatchedIndicesWeights) {
+  std::vector<std::string> table_names = {"tb1", "tb2", "tb3"};
+  std::vector<int64_t> table_offsets = {0, 100, 300};
+  std::vector<int64_t> table_sizes = {0, 50, 200, 300};
+
+  auto streamer = getRawEmbeddingStreamer(
+      "test_mismatch", true, table_names, table_offsets, table_sizes);
+
+  // Mock TrainingParameterServerService
+  auto mock_service = std::make_shared<MockTrainingParameterServerService>();
+  auto mock_server =
+      std::make_shared<apache::thrift::ScopedServerInterfaceThread>(
+          mock_service,
+          "::1",
+          0,
+          facebook::services::TLSConfig::applyDefaultsToThriftServer);
+  auto& mock_client_factory =
+      facebook::servicerouter::getMockSRClientFactory(false /* strict */);
+  mock_client_factory.registerMockService(
+      "realtime.delta.publish.esr", mock_server);
+
+  // Test with mismatched sizes - should not call service
+  auto indices = at::tensor(
+      {10, 2, 1}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+  auto weights = at::randn(
+      {5, EMBEDDING_DIMENSION}, // Different size than indices
+      at::TensorOptions().device(at::kCPU).dtype(c10::kFloat));
+
+  EXPECT_CALL(*mock_service, co_setEmbeddings(_)).Times(0);
+  folly::coro::blockingWait(streamer->tensor_stream(indices, weights));
+}
+#endif
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp
index 402633b583..55e548af0f 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp
@@ -15,14 +15,6 @@
 #include "kv_db_cuda_utils.h"
 #endif
 #include "torch/csrc/autograd/record_function_ops.h"
-#ifdef FBGEMM_FBCODE
-#include <folly/coro/BlockingWait.h>
-#include "aiplatform/gmpp/experimental/training_ps/gen-cpp2/TrainingParameterServerService.h"
-#include "caffe2/torch/fb/distributed/wireSerializer/WireSerializer.h"
-#include "servicerouter/client/cpp2/ClientParams.h"
-#include "servicerouter/client/cpp2/ServiceRouter.h"
-#include "torch/types.h"
-#endif
 
 namespace kv_db {
 
@@ -36,28 +28,6 @@ inline int64_t get_maybe_uvm_scalar(const at::Tensor& tensor) {
       ? *(tensor.data_ptr<int64_t>())
       : *(tensor.data_ptr<int32_t>());
 }
-
-#ifdef FBGEMM_FBCODE
-/*
- * Get the thrift client to the training parameter server service
- * There is a destruction double free issue when wrapping the member
- * variable under ifdef, and creating client is relatively cheap, so create this
- * helper function to get the client just before sending requests.
- */
-std::unique_ptr<apache::thrift::Client<
-    aiplatform::gmpp::experimental::training_ps::
-        TrainingParameterServerService>>
-get_res_client(int64_t res_server_port) {
-  auto& factory = facebook::servicerouter::cpp2::getClientFactory();
-  auto& params = facebook::servicerouter::ClientParams().setSingleHost(
-      "::", res_server_port);
-  return factory.getSRClientUnique<apache::thrift::Client<
-      aiplatform::gmpp::experimental::training_ps::
-          TrainingParameterServerService>>(
-      "realtime.delta.publish.esr", params);
-}
-#endif
-
 }; // namespace
 
 QueueItem tensor_copy(
@@ -118,12 +88,15 @@ EmbeddingKVDB::EmbeddingKVDB(
       max_D_(max_D),
       executor_tp_(std::make_unique<folly::CPUThreadPoolExecutor>(num_shards)),
       enable_async_update_(enable_async_update),
-      enable_raw_embedding_streaming_(enable_raw_embedding_streaming),
-      res_store_shards_(res_store_shards),
-      res_server_port_(res_server_port),
-      table_names_(std::move(table_names)),
-      table_offsets_(std::move(table_offsets)),
-      table_sizes_(at::tensor(table_sizes)) {
+      raw_embedding_streamer_(
+          std::make_unique<fbgemm_gpu::RawEmbeddingStreamer>(
+              std::to_string(unique_id),
+              enable_raw_embedding_streaming,
+              res_store_shards,
+              res_server_port,
+              std::move(table_names),
+              std::move(table_offsets),
+              table_sizes)) {
   CHECK(num_shards > 0);
   if (cache_size_gb > 0) {
     l2_cache::CacheLibCache::CacheConfig cache_config;
@@ -139,8 +112,6 @@ EmbeddingKVDB::EmbeddingKVDB(
   XLOG(INFO) << "[TBE_ID" << unique_id_ << "] L2 created with " << num_shards_
              << " shards, dimension:" << max_D_
              << ", enable_async_update_:" << enable_async_update_
-             << ", enable_raw_embedding_streaming_:"
-             << enable_raw_embedding_streaming_
              << ", cache_size_gb:" << cache_size_gb;
 
   if (enable_async_update_) {
@@ -167,38 +138,6 @@ EmbeddingKVDB::EmbeddingKVDB(
       }
     });
   }
-#ifdef FBGEMM_FBCODE
-  if (enable_raw_embedding_streaming_) {
-    XLOG(INFO) << "[TBE_ID" << unique_id_
-               << "] Raw embedding streaming enabled with res_server_port at"
-               << res_server_port;
-    // The first call to get the client is expensive, so eagerly get it here
-    auto _eager_client = get_res_client(res_server_port_);
-
-    weights_stream_thread_ = std::make_unique<std::thread>([=, this] {
-      while (!stop_) {
-        auto stream_item_ptr = weights_to_stream_queue_.try_peek();
-        if (!stream_item_ptr) {
-          std::this_thread::sleep_for(std::chrono::milliseconds(10));
-          continue;
-        }
-        if (stop_) {
-          return;
-        }
-        auto& indices = stream_item_ptr->indices;
-        auto& weights = stream_item_ptr->weights;
-        folly::stop_watch<std::chrono::milliseconds> stop_watch;
-        folly::coro::blockingWait(tensor_stream(indices, weights));
-
-        weights_to_stream_queue_.dequeue();
-        XLOG_EVERY_MS(INFO, 60000)
-            << "[TBE_ID" << unique_id_
-            << "] end stream queue size: " << weights_to_stream_queue_.size()
-            << " stream takes " << stop_watch.elapsed().count() << "ms";
-      }
-    });
-  }
-#endif
 }
 
 EmbeddingKVDB::~EmbeddingKVDB() {
@@ -206,141 +145,8 @@ EmbeddingKVDB::~EmbeddingKVDB() {
   if (enable_async_update_) {
     cache_filling_thread_->join();
   }
-#ifdef FBGEMM_FBCODE
-  if (enable_raw_embedding_streaming_) {
-    join_stream_tensor_copy_thread();
-    join_weights_stream_thread();
-  }
-#endif
-}
-
-#ifdef FBGEMM_FBCODE
-folly::coro::Task<void> EmbeddingKVDB::tensor_stream(
-    const at::Tensor& indices,
-    const at::Tensor& weights) {
-  using namespace ::aiplatform::gmpp::experimental::training_ps;
-  if (indices.size(0) != weights.size(0)) {
-    XLOG(ERR) << "[TBE_ID" << unique_id_
-              << "] Indices and weights size mismatched " << indices.size(0)
-              << " " << weights.size(0);
-    co_return;
-  }
-  folly::stop_watch<std::chrono::milliseconds> stop_watch;
-  XLOG_EVERY_MS(INFO, 60000)
-      << "[TBE_ID" << unique_id_
-      << "] send streaming request: indices = " << indices.size(0)
-      << ", weights = " << weights.size(0);
-
-  auto biggest_idx = table_sizes_.index({table_sizes_.size(0) - 1});
-  auto mask =
-      at::logical_and(indices >= 0, indices < biggest_idx).nonzero().squeeze();
-  auto filtered_indices = indices.index_select(0, mask);
-  auto filtered_weights = weights.index_select(0, mask);
-  auto num_invalid_indices = indices.size(0) - filtered_indices.size(0);
-  if (num_invalid_indices > 0) {
-    XLOG(INFO) << "[TBE_ID" << unique_id_
-               << "] number of invalid indices: " << num_invalid_indices;
-  }
-  // 1. Transform local row indices to embedding table global row indices
-  at::Tensor table_indices =
-      (at::searchsorted(table_sizes_, filtered_indices, false, true) - 1)
-          .to(torch::kInt8);
-  auto tb_ac = table_indices.accessor<int8_t, 1>();
-  auto indices_ac = filtered_indices.accessor<int64_t, 1>();
-  auto tb_sizes_ac = table_sizes_.accessor<int64_t, 1>();
-  std::vector<int64_t> global_indices(tb_ac.size(0), 0);
-  std::vector<int64_t> shard_indices(tb_ac.size(0), 0);
-
-  for (int i = 0; i < tb_ac.size(0); ++i) {
-    int tb_idx = tb_ac[i];
-    global_indices[i] =
-        indices_ac[i] - tb_sizes_ac[tb_idx] + table_offsets_[tb_idx];
-    // hash to shard
-    // if we do row range sharding, also shard here.
-    auto fqn = table_names_[tb_idx];
-    auto hash_key = folly::to<std::string>(fqn, global_indices[i]);
-    auto shard_id =
-        furcHash(hash_key.data(), hash_key.size(), res_store_shards_);
-    shard_indices[i] = shard_id;
-  }
-  auto global_indices_tensor = at::tensor(global_indices);
-  auto shard_indices_tensor = at::tensor(shard_indices);
-  auto total_rows = global_indices_tensor.size(0);
-  XLOG_EVERY_MS(INFO, 60000)
-      << "[TBE_ID" << unique_id_ << "] hash and gloablize rows " << total_rows
-      << " in: " << stop_watch.elapsed().count() << "ms";
-  stop_watch.reset();
-
-  auto res_client = get_res_client(res_server_port_);
-  // 2. Split by shards
-  for (int i = 0; i < res_store_shards_; ++i) {
-    auto shrad_mask = shard_indices_tensor.eq(i).nonzero().squeeze();
-    auto table_indices_masked = table_indices.index_select(0, shrad_mask);
-    auto rows_in_shard = table_indices_masked.numel();
-    if (rows_in_shard == 0) {
-      continue;
-    }
-    auto global_indices_masked =
-        global_indices_tensor.index_select(0, shrad_mask);
-    auto weights_masked = filtered_weights.index_select(0, shrad_mask);
-
-    if (weights_masked.size(0) != rows_in_shard ||
-        global_indices_masked.numel() != rows_in_shard) {
-      XLOG(ERR)
-          << "[TBE_ID" << unique_id_
-          << "] don't send the request for size mismatched tensors table: "
-          << rows_in_shard << " weights: " << weights_masked.size(0)
-          << " global_indices: " << global_indices_masked.numel();
-      continue;
-    }
-    SetEmbeddingsRequest req;
-    req.shardId() = i;
-    req.fqns() = table_names_;
-
-    req.tableIndices() =
-        torch::distributed::wireDumpTensor(table_indices_masked);
-    req.rowIndices() =
-        torch::distributed::wireDumpTensor(global_indices_masked);
-    req.weights() = torch::distributed::wireDumpTensor(weights_masked);
-    co_await res_client->co_setEmbeddings(req);
-  }
-  co_return;
-}
-
-void EmbeddingKVDB::copy_and_enqueue_stream_tensors(
-    const at::Tensor& indices,
-    const at::Tensor& weights,
-    const at::Tensor& count) {
-  auto rec = torch::autograd::profiler::record_function_enter_new(
-      "## EmbeddingKVDB::copy_and_enqueue_stream_tensors ##");
-  auto stream_item =
-      tensor_copy(indices, weights, count, kv_db::RocksdbWriteMode::STREAM);
-  weights_to_stream_queue_.enqueue(stream_item);
-  rec->record.end();
-}
-
-void EmbeddingKVDB::join_stream_tensor_copy_thread() {
-  auto rec = torch::autograd::profiler::record_function_enter_new(
-      "## EmbeddingKVDB::join_stream_tensor_copy_thread ##");
-  if (stream_tensor_copy_thread_ != nullptr &&
-      stream_tensor_copy_thread_->joinable()) {
-    stream_tensor_copy_thread_->join();
-  }
-  rec->record.end();
-}
-
-void EmbeddingKVDB::join_weights_stream_thread() {
-  if (weights_stream_thread_ != nullptr && weights_stream_thread_->joinable()) {
-    stop_ = true;
-    weights_stream_thread_->join();
-  }
 }
-
-uint64_t EmbeddingKVDB::get_weights_to_stream_queue_size() {
-  return weights_to_stream_queue_.size();
-}
-#endif
-
 void EmbeddingKVDB::update_cache_and_storage(
     const at::Tensor& indices,
     const at::Tensor& weights,
@@ -491,8 +297,14 @@ void EmbeddingKVDB::stream_cuda(
   check_tensor_type_consistency(indices, weights);
   // take reference to self to avoid lifetime issues.
   auto self = shared_from_this();
-  std::function<void()>* functor = new std::function<void()>(
-      [=]() { self->stream(indices, weights, count, blocking_tensor_copy); });
+  std::function<void()>* functor = new std::function<void()>([=]() {
+    self->raw_embedding_streamer_->stream(
+        indices,
+        weights,
+        count,
+        /*require_tensor_copy=*/true,
+        blocking_tensor_copy);
+  });
   AT_CUDA_CHECK(cudaStreamAddCallback(
       at::cuda::getCurrentCUDAStream(),
       kv_db_utils::cuda_callback_func,
@@ -508,8 +320,9 @@ void EmbeddingKVDB::stream_sync_cuda() {
       "## EmbeddingKVDB::stream_sync_cuda ##");
   // take reference to self to avoid lifetime issues.
   auto self = shared_from_this();
-  std::function<void()>* functor = new std::function<void()>(
-      [=]() { self->join_stream_tensor_copy_thread(); });
+  std::function<void()>* functor = new std::function<void()>([=]() {
+    self->raw_embedding_streamer_->join_stream_tensor_copy_thread();
+  });
   AT_CUDA_CHECK(cudaStreamAddCallback(
       at::cuda::getCurrentCUDAStream(),
       kv_db_utils::cuda_callback_func,
@@ -593,7 +406,6 @@ void EmbeddingKVDB::set(
     return;
   }
   CHECK_EQ(max_D_, weights.size(1));
-
   auto rec = torch::autograd::profiler::record_function_enter_new(
       "## EmbeddingKVDB::set_callback ##");
   // defer the L2 cache/rocksdb update to the background thread as it could
@@ -692,32 +504,6 @@ void EmbeddingKVDB::get(
   rec->record.end();
 }
 
-void EmbeddingKVDB::stream(
-    const at::Tensor& indices,
-    const at::Tensor& weights,
-    const at::Tensor& count,
-    bool blocking_tensor_copy) {
-  if (!enable_raw_embedding_streaming_) {
-    return;
-  }
-  auto rec = torch::autograd::profiler::record_function_enter_new(
-      "## EmbeddingKVDB::stream_callback ##");
-  if (blocking_tensor_copy) {
-    copy_and_enqueue_stream_tensors(indices, weights, count);
-    return;
-  }
-  // Make sure the previous thread is done before starting a new one
-  join_stream_tensor_copy_thread();
-  // Cuda dispatches the host callbacks all in the same CPU thread. But the
-  // callbacks don't need to be serialized.
-  // So, We need to spin up a new thread to unblock the CUDA stream, so the CUDA
-  // can continue executing other host callbacks, eg. get/evict.
-  stream_tensor_copy_thread_ = std::make_unique<std::thread>([=, this]() {
-    copy_and_enqueue_stream_tensors(indices, weights, count);
-  });
-  rec->record.end();
-}
-
 std::shared_ptr<CacheContext> EmbeddingKVDB::get_cache(
     const at::Tensor& indices,
     const at::Tensor& count) {
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h
index 656cb61067..40911249f6 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h
@@ -42,6 +42,7 @@
 #include <folly/concurrency/UnboundedQueue.h>
 #include "../dram_kv_embedding_cache/feature_evict.h"
 #include "fbgemm_gpu/split_embeddings_cache/cachelib_cache.h"
+#include "fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h"
 #include "fbgemm_gpu/utils/dispatch_macros.h"
 
 namespace ssd {
@@ -202,33 +203,6 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
       const at::Tensor& count,
       int64_t sleep_ms = 0);
 
-  /// Stream out non-negative elements in <indices> and its paired embeddings
-  /// from <weights> for the first <count> elements in the tensor.
-  /// It spins up a thread that will copy all 3 tensors to CPU and inject them
-  /// into the background queue which will be picked up by another set of
-  /// thread pools for streaming out to the thrift server (co-located on same
-  /// host now).
-  ///
-  /// This is used in cuda stream callback, which doesn't require to be
-  /// serialized with other callbacks, thus a separate thread is used to
-  /// maximize the overlapping with other callbacks.
-  ///
-  /// @param indices The 1D embedding index tensor, should skip on negative
-  /// value
-  /// @param weights The 2D tensor that each row(embeddings) is paired up with
-  /// relative element in <indices>
-  /// @param count A single element tensor that contains the number of indices
-  /// to be processed
-  /// @param blocking_tensor_copy whether to copy the tensors to be streamed in
-  /// a blocking manner
-  ///
-  /// @return None
-  void stream(
-      const at::Tensor& indices,
-      const at::Tensor& weights,
-      const at::Tensor& count,
-      bool blocking_tensor_copy = true);
-
   /// storage tier counterpart of function get()
   virtual folly::SemiFuture<std::vector<folly::Unit>> get_kv_db_async(
       const at::Tensor& indices,
@@ -441,34 +415,6 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
     return 0;
   }
 
-#ifdef FBGEMM_FBCODE
-  folly::coro::Task<void> tensor_stream(
-      const at::Tensor& indices,
-      const at::Tensor& weights);
-  /*
-   * Copy the indices, weights and count tensors and enqueue them for
-   * asynchronous stream.
-   */
-  void copy_and_enqueue_stream_tensors(
-      const at::Tensor& indices,
-      const at::Tensor& weights,
-      const at::Tensor& count);
-
-  /*
-   * Join the stream tensor copy thread, make sure the thread is properly
-   * finished before creating new.
-   */
-  void join_stream_tensor_copy_thread();
-
-  /*
-   * FOR TESTING: Join the weight stream thread, make sure the thread is
-   * properly finished for destruction and testing.
-   */
-  void join_weights_stream_thread();
-  // FOR TESTING: get queue size.
-  uint64_t get_weights_to_stream_queue_size();
-#endif
-
  private:
   /// Find non-negative embedding indices in <indices> and shard them into
   /// #cachelib_pools pieces to be lookedup in parallel
@@ -599,17 +545,7 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
 
   // -- commone path
   std::atomic<int64_t> total_cache_update_duration_{0};
-
-  // -- raw embedding streaming
-  bool enable_raw_embedding_streaming_;
-  int64_t res_store_shards_;
-  int64_t res_server_port_;
-  std::vector<std::string> table_names_;
-  std::vector<int64_t> table_offsets_;
-  at::Tensor table_sizes_;
-  std::unique_ptr<std::thread> weights_stream_thread_;
-  folly::UMPSCQueue<QueueItem, true> weights_to_stream_queue_;
-  std::unique_ptr<std::thread> stream_tensor_copy_thread_;
+  std::unique_ptr<fbgemm_gpu::RawEmbeddingStreamer> raw_embedding_streamer_;
 }; // class EmbeddingKVDB
 
 } // namespace kv_db
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/test/ssd_table_batched_embeddings_test.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/test/ssd_table_batched_embeddings_test.cpp
index 0b2bed94bb..b0e5430a00 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/test/ssd_table_batched_embeddings_test.cpp
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/test/ssd_table_batched_embeddings_test.cpp
@@ -10,12 +10,6 @@
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "deeplearning/fbgemm/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h"
-#ifdef FBGEMM_FBCODE
-#include <folly/coro/GmockHelpers.h>
-#include "aiplatform/gmpp/experimental/training_ps/gen-cpp2/TrainingParameterServerService.h"
-#include "servicerouter/client/cpp2/mocks/MockSRClientFactory.h"
-#include "thrift/lib/cpp2/util/ScopedServerInterfaceThread.h"
-#endif
 
 using namespace ::testing;
 constexpr int64_t EMBEDDING_DIMENSION = 8;
@@ -83,21 +77,6 @@ class MockEmbeddingRocksDB : public ssd::EmbeddingRocksDB {
       (override));
 };
 
-#ifdef FBGEMM_FBCODE
-class MockTrainingParameterServerService
-    : public ::apache::thrift::ServiceHandler<
-          aiplatform::gmpp::experimental::training_ps::
-              TrainingParameterServerService> {
- public:
-  MOCK_METHOD(
-      folly::coro::Task<std::unique_ptr<
-          aiplatform::gmpp::experimental::training_ps::SetEmbeddingsResponse>>,
-      co_setEmbeddings,
-      (std::unique_ptr<
-          aiplatform::gmpp::experimental::training_ps::SetEmbeddingsRequest>));
-};
-#endif
-
 std::unique_ptr<MockEmbeddingRocksDB> getMockEmbeddingRocksDB(
     int num_shards,
     const std::string& dir,
@@ -187,125 +166,3 @@ TEST(SSDTableBatchedEmbeddingsTest, TestToggleCompactionFailOnThronw) {
       { mock_embedding_rocks->toggle_compaction(true); },
       "Failed to toggle compaction to 1 with exception std::runtime_error: some error message");
 }
-
-#ifdef FBGEMM_FBCODE
-TEST(KvDbTableBatchedEmbeddingsTest, TestTensorStream) {
-  int num_shards = 8;
-  std::vector<std::string> table_names = {"tb1", "tb2", "tb3"};
-  std::vector<int64_t> table_offsets = {0, 100, 300};
-  std::vector<int64_t> table_sizes = {0, 50, 200, 300};
-  auto mock_embedding_rocks = getMockEmbeddingRocksDB(
-      num_shards,
-      "tensor_stream",
-      true,
-      table_names,
-      table_offsets,
-      table_sizes);
-  // Mock TrainingParameterServerService
-  auto mock_service = std::make_shared<MockTrainingParameterServerService>();
-  auto mock_server =
-      std::make_shared<apache::thrift::ScopedServerInterfaceThread>(
-          mock_service,
-          "::1",
-          0,
-          facebook::services::TLSConfig::applyDefaultsToThriftServer);
-  auto& mock_client_factory =
-      facebook::servicerouter::getMockSRClientFactory(false /* strict */);
-  mock_client_factory.registerMockService(
-      "realtime.delta.publish.esr", mock_server);
-
-  auto invalid_ind = at::tensor(
-      {300, 301, 999}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
-  auto weights = at::randn(
-      {invalid_ind.size(0), EMBEDDING_DIMENSION},
-      at::TensorOptions().device(at::kCPU).dtype(c10::kFloat));
-  EXPECT_CALL(*mock_service, co_setEmbeddings(_)).Times(0);
-  folly::coro::blockingWait(
-      mock_embedding_rocks->tensor_stream(invalid_ind, weights));
-
-  auto ind = at::tensor(
-      {10, 2, 1, 150, 170, 230, 280},
-      at::TensorOptions().device(at::kCPU).dtype(at::kLong));
-  weights = at::randn(
-      {ind.size(0), 8},
-      at::TensorOptions().device(at::kCPU).dtype(c10::kFloat));
-  EXPECT_CALL(*mock_service, co_setEmbeddings(_))
-      .Times(3) // 3 shards with consistent hashing
-      .WillRepeatedly(folly::coro::gmock_helpers::CoInvoke(
-          [](std::unique_ptr<aiplatform::gmpp::experimental::training_ps::
-                 SetEmbeddingsRequest> request)
-              -> folly::coro::Task<std::unique_ptr<
-                  aiplatform::gmpp::experimental::training_ps::
-                      SetEmbeddingsResponse>> {
-            co_return std::make_unique<
-                aiplatform::gmpp::experimental::training_ps::
-                    SetEmbeddingsResponse>();
-          }));
-  folly::coro::blockingWait(mock_embedding_rocks->tensor_stream(ind, weights));
-}
-#endif
-
-#ifdef FBGEMM_FBCODE
-TEST(KvDbTableBatchedEmbeddingsTest, TestStream) {
-  int num_shards = 8;
-  std::vector<std::string> table_names = {"tb1", "tb2", "tb3"};
-  std::vector<int64_t> table_offsets = {0, 100, 300};
-  std::vector<int64_t> table_sizes = {0, 50, 200, 300};
-  auto mock_embedding_rocks = getMockEmbeddingRocksDB(
-      num_shards, "test_stream", true, table_names, table_offsets, table_sizes);
-  auto ind = at::tensor(
-      {10, 2, 1, 150, 170, 230, 280},
-      at::TensorOptions().device(at::kCPU).dtype(at::kLong));
-  auto weights = at::randn(
-      {ind.size(0), 8},
-      at::TensorOptions().device(at::kCPU).dtype(c10::kFloat));
-  auto count = at::tensor(
-      {ind.size(0)}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
-  // stop the dequeue thread to get accurate queue size
-  mock_embedding_rocks->join_weights_stream_thread();
-
-  // blocking
-  mock_embedding_rocks->stream(ind, weights, count, true);
-  EXPECT_EQ(mock_embedding_rocks->get_weights_to_stream_queue_size(), 1);
-  // non-blocking
-  mock_embedding_rocks->stream(ind, weights, count, false);
-  EXPECT_EQ(mock_embedding_rocks->get_weights_to_stream_queue_size(), 1);
-  mock_embedding_rocks->join_stream_tensor_copy_thread();
-  EXPECT_EQ(mock_embedding_rocks->get_weights_to_stream_queue_size(), 2);
-  mock_embedding_rocks.reset();
-
-  // E2E
-  auto default_response =
-      [](std::unique_ptr<aiplatform::gmpp::experimental::training_ps::
-             SetEmbeddingsRequest> request)
-      -> folly::coro::Task<std::unique_ptr<
-          aiplatform::gmpp::experimental::training_ps::SetEmbeddingsResponse>> {
-    co_return std::make_unique<
-        aiplatform::gmpp::experimental::training_ps::SetEmbeddingsResponse>();
-  };
-  // Mock TrainingParameterServerService
-  auto mock_service = std::make_shared<MockTrainingParameterServerService>();
-  auto mock_server =
-      std::make_shared<apache::thrift::ScopedServerInterfaceThread>(
-          mock_service,
-          "::1",
-          0,
-          facebook::services::TLSConfig::applyDefaultsToThriftServer);
-  auto& mock_client_factory =
-      facebook::servicerouter::getMockSRClientFactory(false /* strict */);
-  mock_client_factory.registerMockService(
-      "realtime.delta.publish.esr", mock_server);
-  EXPECT_CALL(*mock_service, co_setEmbeddings(_))
-      .Times(3) // 3 shards with consistent hashing
-      .WillRepeatedly(folly::coro::gmock_helpers::CoInvoke(default_response));
-
-  mock_embedding_rocks = getMockEmbeddingRocksDB(
-      num_shards, "test_stream", true, table_names, table_offsets, table_sizes);
-  mock_embedding_rocks->stream(ind, weights, count, true);
-  // make sure dequeue finished.
-  std::this_thread::sleep_for(std::chrono::seconds(1));
-  mock_embedding_rocks->join_weights_stream_thread();
-}
-#endif