From 21cbfbbc79e8944a8a210180c706ccf280f54bcb Mon Sep 17 00:00:00 2001
From: Zheng Qi
Date: Mon, 18 Aug 2025 15:40:25 -0700
Subject: [PATCH 1/2] Extract the res backend to a separate class and export
 it to the Python side

Differential Revision: D79192671
---
 fbgemm_gpu/cmake/TbeInference.cmake           |   1 +
 .../raw_embedding_streamer.h                  | 117 +++++++
 .../raw_embedding_streamer.cpp                | 314 ++++++++++++++++++
 .../split_embeddings_cache_ops.cpp            |  36 ++
 .../tests/raw_embedding_streamer_test.cpp     | 265 +++++++++++++++
 .../kv_db_table_batched_embeddings.cpp        | 254 ++------------
 .../kv_db_table_batched_embeddings.h          |  68 +---
 .../ssd_table_batched_embeddings_test.cpp     | 143 --------
 8 files changed, 755 insertions(+), 443 deletions(-)
 create mode 100644 fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h
 create mode 100644 fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp
 create mode 100644 fbgemm_gpu/src/split_embeddings_cache/tests/raw_embedding_streamer_test.cpp

diff --git a/fbgemm_gpu/cmake/TbeInference.cmake b/fbgemm_gpu/cmake/TbeInference.cmake
index 2c3fe235f3..4b75c5149b 100644
--- a/fbgemm_gpu/cmake/TbeInference.cmake
+++ b/fbgemm_gpu/cmake/TbeInference.cmake
@@ -22,6 +22,7 @@ gpu_cpp_library(
     ${FBGEMM_GPU}/src/split_embeddings_cache/lru_cache_populate_byte.cpp
     ${FBGEMM_GPU}/src/split_embeddings_cache/lxu_cache.cpp
     ${FBGEMM_GPU}/src/split_embeddings_cache/split_embeddings_cache_ops.cpp
+    ${FBGEMM_GPU}/src/split_embeddings_cache/raw_embedding_streamer.cpp
   GPU_SRCS
     ${FBGEMM_GPU}/src/split_embeddings_cache/lfu_cache_find.cu
     ${FBGEMM_GPU}/src/split_embeddings_cache/lfu_cache_populate.cu
diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h
new file mode 100644
index 0000000000..e44dc65025
--- /dev/null
+++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h
@@ -0,0 +1,117 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include <ATen/ATen.h>
+#ifdef FBGEMM_FBCODE
+#include <folly/concurrency/UnboundedQueue.h>
+#include <folly/coro/Task.h>
+#endif
+
+#include <torch/custom_class.h>
+
+namespace fbgemm_gpu {
+
+struct StreamQueueItem {
+  at::Tensor indices;
+  at::Tensor weights;
+  at::Tensor count;
+  StreamQueueItem(
+      at::Tensor src_indices,
+      at::Tensor src_weights,
+      at::Tensor src_count) {
+    indices = std::move(src_indices);
+    weights = std::move(src_weights);
+    count = std::move(src_count);
+  }
+};
+
+class RawEmbeddingStreamer : public torch::jit::CustomClassHolder {
+ public:
+  explicit RawEmbeddingStreamer(
+      std::string unique_id,
+      bool enable_raw_embedding_streaming,
+      int64_t res_store_shards,
+      int64_t res_server_port,
+      std::vector<std::string> table_names,
+      std::vector<int64_t> table_offsets,
+      const std::vector<int64_t>& table_sizes);
+
+  virtual ~RawEmbeddingStreamer();
+
+  /// Stream out the non-negative elements in <indices> and their paired
+  /// embeddings from <weights> for the first <count> elements in the tensor.
+  /// It spins up a thread that copies all three tensors to CPU and injects
+  /// them into the background queue, which is drained by a separate set of
+  /// thread pools that stream the rows out to the thrift server (currently
+  /// co-located on the same host).
+  ///
+  /// This is used in a CUDA stream callback, which does not need to be
+  /// serialized with other callbacks, so a separate thread is used to
+  /// maximize the overlap with other callbacks.
+  ///
+  /// @param indices The 1D embedding index tensor; negative values are
+  /// skipped
+  /// @param weights The 2D tensor in which each row (embedding) is paired
+  /// with the corresponding element in <indices>
+  /// @param count A single-element tensor containing the number of indices
+  /// to be processed
+  /// @param require_tensor_copy whether the tensors need to be copied before
+  /// being enqueued; if false, they are enqueued as-is
+  /// @param blocking_tensor_copy whether to copy the tensors to be streamed
+  /// in a blocking manner
+  ///
+  /// @return None
+  void stream(
+      const at::Tensor& indices,
+      const at::Tensor& weights,
+      const at::Tensor& count,
+      bool require_tensor_copy,
+      bool blocking_tensor_copy = true);
+
+#ifdef FBGEMM_FBCODE
+  folly::coro::Task<void> tensor_stream(
+      const at::Tensor& indices,
+      const at::Tensor& weights);
+  /*
+   * Copy the indices, weights and count tensors and enqueue them for
+   * asynchronous streaming.
+   */
+  void copy_and_enqueue_stream_tensors(
+      const at::Tensor& indices,
+      const at::Tensor& weights,
+      const at::Tensor& count);
+
+  /*
+   * Join the stream tensor copy thread, making sure the thread has properly
+   * finished before creating a new one.
+   */
+  void join_stream_tensor_copy_thread();
+
+  /*
+   * FOR TESTING: Join the weights stream thread, making sure the thread has
+   * properly finished for destruction and testing.
+   */
+  void join_weights_stream_thread();
+  // FOR TESTING: get queue size.
+  uint64_t get_weights_to_stream_queue_size();
+#endif
+ private:
+  std::atomic<bool> stop_{false};
+  std::string unique_id_;
+  bool enable_raw_embedding_streaming_;
+  int64_t res_store_shards_;
+  int64_t res_server_port_;
+  std::vector<std::string> table_names_;
+  std::vector<int64_t> table_offsets_;
+  at::Tensor table_sizes_;
+#ifdef FBGEMM_FBCODE
+  std::unique_ptr<std::thread> weights_stream_thread_;
+  folly::UMPSCQueue<StreamQueueItem, true> weights_to_stream_queue_;
+  std::unique_ptr<std::thread> stream_tensor_copy_thread_;
+#endif
+};
+
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp b/fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp
new file mode 100644
index 0000000000..ffee3ff672
--- /dev/null
+++ b/fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp
@@ -0,0 +1,314 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#ifdef FBGEMM_FBCODE
+#include <folly/coro/BlockingWait.h>
+#include <folly/logging/xlog.h>
+#include <folly/stop_watch.h>
+#include "aiplatform/gmpp/experimental/training_ps/gen-cpp2/TrainingParameterServerService.h"
+#include "caffe2/torch/fb/distributed/wireSerializer/WireSerializer.h"
+#include "servicerouter/client/cpp2/ClientParams.h"
+#include "servicerouter/client/cpp2/ServiceRouter.h"
+#include "torch/csrc/autograd/record_function_ops.h"
+#include "torch/types.h"
+#endif
+#include "fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h"
+#include "fbgemm_gpu/utils/dispatch_macros.h"
+
+namespace fbgemm_gpu {
+namespace {
+
+#ifdef FBGEMM_FBCODE
+/*
+ * Get the thrift client to the training parameter server service.
+ * There is a destruction double-free issue when wrapping the member
+ * variable under ifdef, and creating a client is relatively cheap, so this
+ * helper function gets the client just before sending requests.
+ */
+std::unique_ptr<
+    apache::thrift::Client<aiplatform::gmpp::experimental::training_ps::
+                               TrainingParameterServerService>>
+get_res_client(int64_t res_server_port) {
+  auto& factory = facebook::servicerouter::cpp2::getClientFactory();
+  auto& params = facebook::servicerouter::ClientParams().setSingleHost(
+      "::", res_server_port);
+  return factory.getSRClientUnique<
+      apache::thrift::Client<aiplatform::gmpp::experimental::training_ps::
+                                 TrainingParameterServerService>>(
+      "realtime.delta.publish.esr", params);
+}
+#endif
+
+/// Read a scalar value from a tensor that is maybe a UVM tensor
+/// Note that `tensor.item()` is not allowed on a UVM tensor in
+/// PyTorch
+inline int64_t get_maybe_uvm_scalar(const at::Tensor& tensor) {
+  return tensor.scalar_type() == at::ScalarType::Long
+      ? *(tensor.data_ptr<int64_t>())
+      : *(tensor.data_ptr<int32_t>());
+}
+
+fbgemm_gpu::StreamQueueItem tensor_copy(
+    const at::Tensor& indices,
+    const at::Tensor& weights,
+    const at::Tensor& count) {
+  auto num_sets = get_maybe_uvm_scalar(count);
+  auto new_indices = at::empty(
+      num_sets, at::TensorOptions().device(at::kCPU).dtype(indices.dtype()));
+  auto new_weights = at::empty(
+      {num_sets, weights.size(1)},
+      at::TensorOptions().device(at::kCPU).dtype(weights.dtype()));
+  auto new_count =
+      at::empty({1}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+  FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
+      weights.scalar_type(), "tensor_copy", [&] {
+        using value_t = scalar_t;
+        FBGEMM_DISPATCH_INTEGRAL_TYPES(
+            indices.scalar_type(), "tensor_copy", [&] {
+              using index_t = scalar_t;
+              auto indices_addr = indices.data_ptr<index_t>();
+              auto new_indices_addr = new_indices.data_ptr<index_t>();
+              std::copy(
+                  indices_addr,
+                  indices_addr + num_sets,
+                  new_indices_addr); // dst_start
+
+              auto weights_addr = weights.data_ptr<value_t>();
+              auto new_weights_addr = new_weights.data_ptr<value_t>();
+              std::copy(
+                  weights_addr,
+                  weights_addr + num_sets * weights.size(1),
+                  new_weights_addr); // dst_start
+            });
+      });
+  *new_count.data_ptr<int64_t>() = num_sets;
+  return fbgemm_gpu::StreamQueueItem{new_indices, new_weights, new_count};
+}
+
+} // namespace
+
+RawEmbeddingStreamer::RawEmbeddingStreamer(
+    std::string unique_id,
+    bool enable_raw_embedding_streaming,
+    int64_t res_store_shards,
+    int64_t res_server_port,
+    std::vector<std::string> table_names,
+    std::vector<int64_t> table_offsets,
+    const std::vector<int64_t>& table_sizes)
+    : unique_id_(std::move(unique_id)),
+      enable_raw_embedding_streaming_(enable_raw_embedding_streaming),
+      res_store_shards_(res_store_shards),
+      res_server_port_(res_server_port),
+      table_names_(std::move(table_names)),
+      table_offsets_(std::move(table_offsets)),
+      table_sizes_(at::tensor(table_sizes)) {
+#ifdef FBGEMM_FBCODE
+  if (enable_raw_embedding_streaming_) {
+    XLOG(INFO) << "[TBE_ID" << unique_id_
+               << "] Raw embedding streaming enabled with res_server_port at "
+               << res_server_port;
+    // The first call to get the client is expensive, so eagerly get it here
+    auto _eager_client = get_res_client(res_server_port_);
+
+    weights_stream_thread_ = std::make_unique<std::thread>([=, this] {
+      while (!stop_) {
+        auto stream_item_ptr = weights_to_stream_queue_.try_peek();
+        if (!stream_item_ptr) {
+          std::this_thread::sleep_for(std::chrono::milliseconds(10));
+          continue;
+        }
+        if (stop_) {
+          return;
+        }
+        auto& indices = stream_item_ptr->indices;
+        auto& weights = stream_item_ptr->weights;
+        folly::stop_watch<std::chrono::milliseconds> stop_watch;
+        folly::coro::blockingWait(tensor_stream(indices, weights));
+
+        weights_to_stream_queue_.dequeue();
+        XLOG_EVERY_MS(INFO, 60000)
+            << "[TBE_ID" << unique_id_
+            << "] end stream queue size: " << weights_to_stream_queue_.size()
+            << " stream takes " << stop_watch.elapsed().count() << "ms";
+      }
+    });
+  }
+#endif
+}
+
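+// Shutdown note: the destructor below sets stop_ before joining the copy and
+// stream threads; the loop above re-checks stop_ after try_peek() so a queued
+// item cannot keep the background thread alive during destruction.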
+RawEmbeddingStreamer::~RawEmbeddingStreamer() {
+  stop_ = true;
+#ifdef FBGEMM_FBCODE
+  if (enable_raw_embedding_streaming_) {
+    join_stream_tensor_copy_thread();
+    join_weights_stream_thread();
+  }
+#endif
+}
+
+void RawEmbeddingStreamer::stream(
+    const at::Tensor& indices,
+    const at::Tensor& weights,
+    const at::Tensor& count,
+    bool require_tensor_copy,
+    bool blocking_tensor_copy) {
+  if (!enable_raw_embedding_streaming_) {
+    return;
+  }
+#ifdef FBGEMM_FBCODE
+  auto rec = torch::autograd::profiler::record_function_enter_new(
+      "## RawEmbeddingStreamer::stream_callback ##");
+  if (!require_tensor_copy) {
+    StreamQueueItem stream_item(indices, weights, count);
+    weights_to_stream_queue_.enqueue(stream_item);
+    return;
+  }
+  if (blocking_tensor_copy) {
+    copy_and_enqueue_stream_tensors(indices, weights, count);
+    return;
+  }
+  // Make sure the previous thread is done before starting a new one
+  join_stream_tensor_copy_thread();
+  // CUDA dispatches the host callbacks all in the same CPU thread, but the
+  // callbacks don't need to be serialized, so we spin up a new thread to
+  // unblock the CUDA stream and let it continue executing other host
+  // callbacks, e.g. get/evict.
+  stream_tensor_copy_thread_ = std::make_unique<std::thread>([=, this]() {
+    copy_and_enqueue_stream_tensors(indices, weights, count);
+  });
+  rec->record.end();
+#endif
+}
+
+#ifdef FBGEMM_FBCODE
+folly::coro::Task<void> RawEmbeddingStreamer::tensor_stream(
+    const at::Tensor& indices,
+    const at::Tensor& weights) {
+  using namespace ::aiplatform::gmpp::experimental::training_ps;
+  if (indices.size(0) != weights.size(0)) {
+    XLOG(ERR) << "[TBE_ID" << unique_id_
+              << "] Indices and weights size mismatched " << indices.size(0)
+              << " " << weights.size(0);
+    co_return;
+  }
+  folly::stop_watch<std::chrono::milliseconds> stop_watch;
+  XLOG_EVERY_MS(INFO, 60000)
+      << "[TBE_ID" << unique_id_
+      << "] send streaming request: indices = " << indices.size(0)
+      << ", weights = " << weights.size(0);
+
+  auto biggest_idx = table_sizes_.index({table_sizes_.size(0) - 1});
+  auto mask =
+      at::logical_and(indices >= 0, indices < biggest_idx).nonzero().squeeze();
+  auto filtered_indices = indices.index_select(0, mask);
+  auto filtered_weights = weights.index_select(0, mask);
+  auto num_invalid_indices = indices.size(0) - filtered_indices.size(0);
+  if (num_invalid_indices > 0) {
+    XLOG(INFO) << "[TBE_ID" << unique_id_
+               << "] number of invalid indices: " << num_invalid_indices;
+  }
+  // 1. Transform local row indices to embedding table global row indices
+  at::Tensor table_indices =
+      (at::searchsorted(table_sizes_, filtered_indices, false, true) - 1)
+          .to(torch::kInt8);
+  auto tb_ac = table_indices.accessor<int8_t, 1>();
+  auto indices_ac = filtered_indices.accessor<int64_t, 1>();
+  auto tb_sizes_ac = table_sizes_.accessor<int64_t, 1>();
+  std::vector<int64_t> global_indices(tb_ac.size(0), 0);
+  std::vector<int64_t> shard_indices(tb_ac.size(0), 0);
+
+  for (int i = 0; i < tb_ac.size(0); ++i) {
+    auto tb_idx = tb_ac[i];
+    global_indices[i] =
+        indices_ac[i] - tb_sizes_ac[tb_idx] + table_offsets_[tb_idx];
+    // hash to shard
+    // if we do row range sharding, also shard here.
+    auto fqn = table_names_[tb_idx];
+    auto hash_key = folly::to<std::string>(fqn, global_indices[i]);
+    auto shard_id =
+        furcHash(hash_key.data(), hash_key.size(), res_store_shards_);
+    shard_indices[i] = shard_id;
+  }
+  auto global_indices_tensor = at::tensor(global_indices);
+  auto shard_indices_tensor = at::tensor(shard_indices);
+  auto total_rows = global_indices_tensor.size(0);
+  XLOG_EVERY_MS(INFO, 60000)
+      << "[TBE_ID" << unique_id_ << "] hash and globalize rows " << total_rows
+      << " in: " << stop_watch.elapsed().count() << "ms";
+  stop_watch.reset();
+
+  auto res_client = get_res_client(res_server_port_);
+  // 2. Split by shards
+  for (int i = 0; i < res_store_shards_; ++i) {
+    auto shard_mask = shard_indices_tensor.eq(i).nonzero().squeeze();
+    auto table_indices_masked = table_indices.index_select(0, shard_mask);
+    auto rows_in_shard = table_indices_masked.numel();
+    if (rows_in_shard == 0) {
+      continue;
+    }
+    auto global_indices_masked =
+        global_indices_tensor.index_select(0, shard_mask);
+    auto weights_masked = filtered_weights.index_select(0, shard_mask);
+
+    if (weights_masked.size(0) != rows_in_shard ||
+        global_indices_masked.numel() != rows_in_shard) {
+      XLOG(ERR)
+          << "[TBE_ID" << unique_id_
+          << "] don't send the request for size-mismatched tensors, table: "
+          << rows_in_shard << " weights: " << weights_masked.size(0)
+          << " global_indices: " << global_indices_masked.numel();
+      continue;
+    }
+    SetEmbeddingsRequest req;
+    req.shardId() = i;
+    req.fqns() = table_names_;
+
+    req.tableIndices() =
+        torch::distributed::wireDumpTensor(table_indices_masked);
+    req.rowIndices() =
+        torch::distributed::wireDumpTensor(global_indices_masked);
+    req.weights() = torch::distributed::wireDumpTensor(weights_masked);
+    co_await res_client->co_setEmbeddings(req);
+  }
+  co_return;
+}
+
+void RawEmbeddingStreamer::copy_and_enqueue_stream_tensors(
+    const at::Tensor& indices,
+    const at::Tensor& weights,
+    const at::Tensor& count) {
+  auto rec = torch::autograd::profiler::record_function_enter_new(
+      "## RawEmbeddingStreamer::copy_and_enqueue_stream_tensors ##");
+  auto stream_item = tensor_copy(indices, weights, count);
+  weights_to_stream_queue_.enqueue(stream_item);
+  rec->record.end();
+}
+
+void RawEmbeddingStreamer::join_stream_tensor_copy_thread() {
+  auto rec = torch::autograd::profiler::record_function_enter_new(
+      "## RawEmbeddingStreamer::join_stream_tensor_copy_thread ##");
+  if (stream_tensor_copy_thread_ != nullptr &&
+      stream_tensor_copy_thread_->joinable()) {
+    stream_tensor_copy_thread_->join();
+  }
+  rec->record.end();
+}
+
+void RawEmbeddingStreamer::join_weights_stream_thread() {
+  if (weights_stream_thread_ != nullptr && weights_stream_thread_->joinable()) {
+    stop_ = true;
+    weights_stream_thread_->join();
+  }
+}
+
+uint64_t RawEmbeddingStreamer::get_weights_to_stream_queue_size() {
+  return weights_to_stream_queue_.size();
+}
+#endif
+
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp
index 83d86fd097..89b1e9d720 100644
--- a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp
+++ b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp
@@ -7,6 +7,7 @@
  */
 
 #include "common.h"
+#include "fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h"
 
 namespace {
 
@@ -73,4 +74,39 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   DISPATCH_TO_META("lxu_cache_lookup", lxu_cache_lookup_meta);
 }
 
+static auto raw_embedding_streamer =
+    torch::class_<fbgemm_gpu::RawEmbeddingStreamer>(
+        "fbgemm",
+        "RawEmbeddingStreamer")
+        .def(
+            torch::init<
+                std::string,
+                bool,
+                int64_t,
+                int64_t,
+                std::vector<std::string>,
+                std::vector<int64_t>,
+                std::vector<int64_t>>(),
+            "",
+            {
+                torch::arg("unique_id") = 0,
+                torch::arg("enable_raw_embedding_streaming") = false,
+                torch::arg("res_store_shards") = 0,
+                torch::arg("res_server_port") = 0,
+                torch::arg("table_names") = torch::List<std::string>(),
+                torch::arg("table_offsets") = torch::List<int64_t>(),
+                torch::arg("table_sizes") = torch::List<int64_t>(),
+            })
+        .def(
+            "stream",
+            &fbgemm_gpu::RawEmbeddingStreamer::stream,
+            "",
+            {
+                torch::arg("indices"),
+                torch::arg("weights"),
+                torch::arg("count"),
+                torch::arg("require_tensor_copy"),
+                torch::arg("blocking_tensor_copy"),
+            });
+
 } // namespace
diff --git a/fbgemm_gpu/src/split_embeddings_cache/tests/raw_embedding_streamer_test.cpp b/fbgemm_gpu/src/split_embeddings_cache/tests/raw_embedding_streamer_test.cpp
new file mode 100644
index 0000000000..3b408f7a93
--- /dev/null
+++ b/fbgemm_gpu/src/split_embeddings_cache/tests/raw_embedding_streamer_test.cpp
@@ -0,0 +1,265 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "deeplearning/fbgemm/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h" // @manual=//deeplearning/fbgemm/fbgemm_gpu/src/split_embeddings_cache:raw_embedding_streamer
+#ifdef FBGEMM_FBCODE
+#include <folly/coro/GmockHelpers.h>
+#include "aiplatform/gmpp/experimental/training_ps/gen-cpp2/TrainingParameterServerService.h"
+#include "servicerouter/client/cpp2/mocks/MockSRClientFactory.h"
+#include "thrift/lib/cpp2/util/ScopedServerInterfaceThread.h"
+#endif
+
+using namespace ::testing;
+using namespace fbgemm_gpu;
+constexpr int64_t EMBEDDING_DIMENSION = 8;
+
+#ifdef FBGEMM_FBCODE
+class MockTrainingParameterServerService
+    : public ::apache::thrift::ServiceHandler<
+          aiplatform::gmpp::experimental::training_ps::
+              TrainingParameterServerService> {
+ public:
+  MOCK_METHOD(
+      folly::coro::Task<std::unique_ptr<
+          aiplatform::gmpp::experimental::training_ps::SetEmbeddingsResponse>>,
+      co_setEmbeddings,
+      (std::unique_ptr<
+          aiplatform::gmpp::experimental::training_ps::SetEmbeddingsRequest>));
+};
+#endif
+
+static std::unique_ptr<fbgemm_gpu::RawEmbeddingStreamer>
+getRawEmbeddingStreamer(
+    const std::string& unique_id,
+    bool enable_raw_embedding_streaming = false,
+    const std::vector<std::string>& table_names = {},
+    const std::vector<int64_t>& table_offsets = {},
+    const std::vector<int64_t>& table_sizes = {}) {
+  return std::make_unique<fbgemm_gpu::RawEmbeddingStreamer>(
+      unique_id,
+      enable_raw_embedding_streaming,
+      3, // res_store_shards
+      0, // res_server_port
+      table_names,
+      table_offsets,
+      table_sizes);
+}
+
+TEST(RawEmbeddingStreamerTest, TestConstructorAndDestructor) {
+  std::vector<std::string> table_names = {"tb1", "tb2", "tb3"};
+  std::vector<int64_t> table_offsets = {0, 100, 300};
+  std::vector<int64_t> table_sizes = {0, 50, 200, 300};
+
+  auto streamer = getRawEmbeddingStreamer(
+      "test_constructor", false, table_names, table_offsets, table_sizes);
+  EXPECT_NE(streamer, nullptr);
+}
+
+TEST(RawEmbeddingStreamerTest, TestStreamWithoutStreaming) {
+  std::vector<std::string> table_names = {"tb1", "tb2", "tb3"};
+  std::vector<int64_t> table_offsets = {0, 100, 300};
+  std::vector<int64_t> table_sizes = {0, 50, 200, 300};
+
+  auto streamer = getRawEmbeddingStreamer(
+      "test_no_streaming", false, table_names, table_offsets, table_sizes);
+
+  auto indices = at::tensor(
+      {10, 2, 1, 150, 170, 230, 280},
+      at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+  auto weights = at::randn(
+      {indices.size(0), EMBEDDING_DIMENSION},
+      at::TensorOptions().device(at::kCPU).dtype(c10::kFloat));
+  auto count = at::tensor(
+      {indices.size(0)}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+
+  // Should not crash when streaming is disabled
+  streamer->stream(indices, weights, count, true, true);
+}
+
+#ifdef FBGEMM_FBCODE
+TEST(RawEmbeddingStreamerTest, TestTensorStream) {
+  std::vector<std::string> table_names = {"tb1", "tb2", "tb3"};
+  std::vector<int64_t> table_offsets = {0, 100, 300};
+  std::vector<int64_t> table_sizes = {0, 50, 200, 300};
+
+  auto streamer = getRawEmbeddingStreamer(
+      "test_tensor_stream", true, table_names, table_offsets, table_sizes);
+
+  // Mock TrainingParameterServerService
+  auto mock_service = std::make_shared<MockTrainingParameterServerService>();
+  auto mock_server =
+      std::make_shared<apache::thrift::ScopedServerInterfaceThread>(
+          mock_service,
+          "::1",
+          0,
+          facebook::services::TLSConfig::applyDefaultsToThriftServer);
+  auto& mock_client_factory =
+      facebook::servicerouter::getMockSRClientFactory(false /* strict */);
+  mock_client_factory.registerMockService(
+      "realtime.delta.publish.esr", mock_server);
+
+  // Test with invalid indices - should not call service
+  auto invalid_indices = at::tensor(
+      {300, 301, 999}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+  auto weights = at::randn(
+      {invalid_indices.size(0), EMBEDDING_DIMENSION},
+      at::TensorOptions().device(at::kCPU).dtype(c10::kFloat));
+  EXPECT_CALL(*mock_service, co_setEmbeddings(_)).Times(0);
+  folly::coro::blockingWait(streamer->tensor_stream(invalid_indices, weights));
+
+  // Test with valid indices - should call service
+  auto valid_indices = at::tensor(
+      {10, 2, 1, 150, 170, 230, 280},
+      at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+  weights = at::randn(
+      {valid_indices.size(0), EMBEDDING_DIMENSION},
+      at::TensorOptions().device(at::kCPU).dtype(c10::kFloat));
+  EXPECT_CALL(*mock_service, co_setEmbeddings(_))
+      .Times(3) // 3 shards with consistent hashing
+      .WillRepeatedly(folly::coro::gmock_helpers::CoInvoke(
+          [](std::unique_ptr<
+              aiplatform::gmpp::experimental::training_ps::SetEmbeddingsRequest>
+                 request)
+              -> folly::coro::Task<std::unique_ptr<
+                  aiplatform::gmpp::experimental::training_ps::
+                      SetEmbeddingsResponse>> {
+            co_return std::make_unique<
+                aiplatform::gmpp::experimental::training_ps::
+                    SetEmbeddingsResponse>();
+          }));
+  folly::coro::blockingWait(streamer->tensor_stream(valid_indices, weights));
+}
+
+TEST(RawEmbeddingStreamerTest, TestStreamWithCopy) {
+  std::vector<std::string> table_names = {"tb1", "tb2", "tb3"};
+  std::vector<int64_t> table_offsets = {0, 100, 300};
+  std::vector<int64_t> table_sizes = {0, 50, 200, 300};
+
+  auto streamer = getRawEmbeddingStreamer(
+      "test_stream_copy", true, table_names, table_offsets, table_sizes);
+
+  // Mock TrainingParameterServerService
+  auto mock_service = std::make_shared<MockTrainingParameterServerService>();
+  auto mock_server =
+      std::make_shared<apache::thrift::ScopedServerInterfaceThread>(
+          mock_service,
+          "::1",
+          0,
+          facebook::services::TLSConfig::applyDefaultsToThriftServer);
+  auto& mock_client_factory =
+      facebook::servicerouter::getMockSRClientFactory(false /* strict */);
+  mock_client_factory.registerMockService(
+      "realtime.delta.publish.esr", mock_server);
+
+  auto indices = at::tensor(
+      {10, 2, 1, 150, 170, 230, 280},
+      at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+  auto weights = at::randn(
+      {indices.size(0), EMBEDDING_DIMENSION},
+      at::TensorOptions().device(at::kCPU).dtype(c10::kFloat));
+  auto count = at::tensor(
+      {indices.size(0)}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+
+  // Stop the dequeue thread to get accurate queue size
+  streamer->join_weights_stream_thread();
+
+  // Test blocking tensor copy
+  streamer->stream(indices, weights, count, true, true);
+  EXPECT_EQ(streamer->get_weights_to_stream_queue_size(), 1);
+
+  // Test non-blocking tensor copy
+  streamer->stream(indices, weights, count, true, false);
+  EXPECT_EQ(streamer->get_weights_to_stream_queue_size(), 1);
+  streamer->join_stream_tensor_copy_thread();
+  EXPECT_EQ(streamer->get_weights_to_stream_queue_size(), 2);
+}
+
+TEST(RawEmbeddingStreamerTest, TestStreamE2E) {
+  std::vector<std::string> table_names = {"tb1", "tb2", "tb3"};
+  std::vector<int64_t> table_offsets = {0, 100, 300};
+  std::vector<int64_t> table_sizes = {0, 50, 200, 300};
+
+  // Mock TrainingParameterServerService
+  auto mock_service = std::make_shared<MockTrainingParameterServerService>();
+  auto mock_server =
+      std::make_shared<apache::thrift::ScopedServerInterfaceThread>(
+          mock_service,
+          "::1",
+          0,
+          facebook::services::TLSConfig::applyDefaultsToThriftServer);
+  auto& mock_client_factory =
+      facebook::servicerouter::getMockSRClientFactory(false /* strict */);
+  mock_client_factory.registerMockService(
+      "realtime.delta.publish.esr", mock_server);
+
+  auto default_response =
+      [](std::unique_ptr<
+          aiplatform::gmpp::experimental::training_ps::SetEmbeddingsRequest>
+             request)
+      -> folly::coro::Task<std::unique_ptr<
+          aiplatform::gmpp::experimental::training_ps::SetEmbeddingsResponse>> {
+    co_return std::make_unique<
+        aiplatform::gmpp::experimental::training_ps::SetEmbeddingsResponse>();
+  };
+
+  EXPECT_CALL(*mock_service, co_setEmbeddings(_))
+      .Times(3) // 3 shards with consistent hashing
+      .WillRepeatedly(folly::coro::gmock_helpers::CoInvoke(default_response));
+
+  auto streamer = getRawEmbeddingStreamer(
+      "test_stream_e2e", true, table_names, table_offsets, table_sizes);
+
+  auto indices = at::tensor(
+      {10, 2, 1, 150, 170, 230, 280},
+      at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+  auto weights = at::randn(
+      {indices.size(0), EMBEDDING_DIMENSION},
+      at::TensorOptions().device(at::kCPU).dtype(c10::kFloat));
+  auto count = at::tensor(
+      {indices.size(0)}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+
+  streamer->stream(indices, weights, count, true, true);
+  // Make sure dequeue finished
+  std::this_thread::sleep_for(std::chrono::seconds(1));
+  streamer->join_weights_stream_thread();
+}
+
+TEST(RawEmbeddingStreamerTest, TestMismatchedIndicesWeights) {
+  std::vector<std::string> table_names = {"tb1", "tb2", "tb3"};
+  std::vector<int64_t> table_offsets = {0, 100, 300};
+  std::vector<int64_t> table_sizes = {0, 50, 200, 300};
+
+  auto streamer = getRawEmbeddingStreamer(
+      "test_mismatch", true, table_names, table_offsets, table_sizes);
+
+  // Mock TrainingParameterServerService
+  auto mock_service = std::make_shared<MockTrainingParameterServerService>();
+  auto mock_server =
+      std::make_shared<apache::thrift::ScopedServerInterfaceThread>(
+          mock_service,
+          "::1",
+          0,
+          facebook::services::TLSConfig::applyDefaultsToThriftServer);
+  auto& mock_client_factory =
+      facebook::servicerouter::getMockSRClientFactory(false /* strict */);
+  mock_client_factory.registerMockService(
+      "realtime.delta.publish.esr", mock_server);
+
+  // Test with mismatched sizes - should not call service
+  auto indices = at::tensor(
+      {10, 2, 1}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+  auto weights = at::randn(
+      {5, EMBEDDING_DIMENSION}, // Different size than indices
+      at::TensorOptions().device(at::kCPU).dtype(c10::kFloat));
+
+  EXPECT_CALL(*mock_service, co_setEmbeddings(_)).Times(0);
+  folly::coro::blockingWait(streamer->tensor_stream(indices, weights));
+}
+#endif
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp
index 402633b583..55e548af0f 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp
+++ 
b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp
@@ -15,14 +15,6 @@
 #include "kv_db_cuda_utils.h"
 #endif
 #include "torch/csrc/autograd/record_function_ops.h"
-#ifdef FBGEMM_FBCODE
-#include <folly/stop_watch.h>
-#include "aiplatform/gmpp/experimental/training_ps/gen-cpp2/TrainingParameterServerService.h"
-#include "caffe2/torch/fb/distributed/wireSerializer/WireSerializer.h"
-#include "servicerouter/client/cpp2/ClientParams.h"
-#include "servicerouter/client/cpp2/ServiceRouter.h"
-#include "torch/types.h"
-#endif
 
 namespace kv_db {
 
@@ -36,28 +28,6 @@ inline int64_t get_maybe_uvm_scalar(const at::Tensor& tensor) {
       ? *(tensor.data_ptr<int64_t>())
       : *(tensor.data_ptr<int32_t>());
 }
-
-#ifdef FBGEMM_FBCODE
-/*
- * Get the thrift client to the training parameter server service
- * There is a destruction double free issue when wrapping the member
- * variable under ifdef, and creating client is relatively cheap, so create this
- * helper function to get the client just before sending requests.
- */
-std::unique_ptr<
-    apache::thrift::Client<aiplatform::gmpp::experimental::training_ps::
-                               TrainingParameterServerService>>
-get_res_client(int64_t res_server_port) {
-  auto& factory = facebook::servicerouter::cpp2::getClientFactory();
-  auto& params = facebook::servicerouter::ClientParams().setSingleHost(
-      "::", res_server_port);
-  return factory.getSRClientUnique<
-      apache::thrift::Client<aiplatform::gmpp::experimental::training_ps::
-                                 TrainingParameterServerService>>(
-      "realtime.delta.publish.esr", params);
-}
-#endif
-
 }; // namespace
 
 QueueItem tensor_copy(
@@ -118,12 +88,15 @@ EmbeddingKVDB::EmbeddingKVDB(
       max_D_(max_D),
       executor_tp_(std::make_unique<folly::CPUThreadPoolExecutor>(num_shards)),
       enable_async_update_(enable_async_update),
-      enable_raw_embedding_streaming_(enable_raw_embedding_streaming),
-      res_store_shards_(res_store_shards),
-      res_server_port_(res_server_port),
-      table_names_(std::move(table_names)),
-      table_offsets_(std::move(table_offsets)),
-      table_sizes_(at::tensor(table_sizes)) {
+      raw_embedding_streamer_(
+          std::make_unique<fbgemm_gpu::RawEmbeddingStreamer>(
+              std::to_string(unique_id),
+              enable_raw_embedding_streaming,
+              res_store_shards,
+              res_server_port,
+              std::move(table_names),
+              std::move(table_offsets),
+              table_sizes)) {
   CHECK(num_shards > 0);
   if (cache_size_gb > 0) {
     l2_cache::CacheLibCache::CacheConfig cache_config;
@@ -139,8 +112,6 @@ EmbeddingKVDB::EmbeddingKVDB(
   XLOG(INFO) << "[TBE_ID" << unique_id_ << "] L2 created with " << num_shards_
              << " shards, dimension:" << max_D_
              << ", enable_async_update_:" << enable_async_update_
-             << ", enable_raw_embedding_streaming_:"
-             << enable_raw_embedding_streaming_
              << ", cache_size_gb:" << cache_size_gb;
 
   if (enable_async_update_) {
@@ -167,38 +138,6 @@ EmbeddingKVDB::EmbeddingKVDB(
     }
   });
 }
-#ifdef FBGEMM_FBCODE
-  if (enable_raw_embedding_streaming_) {
-    XLOG(INFO) << "[TBE_ID" << unique_id_
-               << "] Raw embedding streaming enabled with res_server_port at"
-               << res_server_port;
-    // The first call to get the client is expensive, so eagerly get it here
-    auto _eager_client = get_res_client(res_server_port_);
-
-    weights_stream_thread_ = std::make_unique<std::thread>([=, this] {
-      while (!stop_) {
-        auto stream_item_ptr = weights_to_stream_queue_.try_peek();
-        if (!stream_item_ptr) {
-          std::this_thread::sleep_for(std::chrono::milliseconds(10));
-          continue;
-        }
-        if (stop_) {
-          return;
-        }
-        auto& indices = stream_item_ptr->indices;
-        auto& weights = stream_item_ptr->weights;
-        folly::stop_watch<std::chrono::milliseconds> stop_watch;
-        folly::coro::blockingWait(tensor_stream(indices, weights));
-
-        weights_to_stream_queue_.dequeue();
-        XLOG_EVERY_MS(INFO, 60000)
-            << "[TBE_ID" << unique_id_
-            << "] end stream queue size: " << weights_to_stream_queue_.size()
-            << " stream takes " << 
stop_watch.elapsed().count() << "ms"; - } - }); - } -#endif } EmbeddingKVDB::~EmbeddingKVDB() { @@ -206,141 +145,8 @@ EmbeddingKVDB::~EmbeddingKVDB() { if (enable_async_update_) { cache_filling_thread_->join(); } -#ifdef FBGEMM_FBCODE - if (enable_raw_embedding_streaming_) { - join_stream_tensor_copy_thread(); - join_weights_stream_thread(); - } -#endif -} - -#ifdef FBGEMM_FBCODE -folly::coro::Task EmbeddingKVDB::tensor_stream( - const at::Tensor& indices, - const at::Tensor& weights) { - using namespace ::aiplatform::gmpp::experimental::training_ps; - if (indices.size(0) != weights.size(0)) { - XLOG(ERR) << "[TBE_ID" << unique_id_ - << "] Indices and weights size mismatched " << indices.size(0) - << " " << weights.size(0); - co_return; - } - folly::stop_watch stop_watch; - XLOG_EVERY_MS(INFO, 60000) - << "[TBE_ID" << unique_id_ - << "] send streaming request: indices = " << indices.size(0) - << ", weights = " << weights.size(0); - - auto biggest_idx = table_sizes_.index({table_sizes_.size(0) - 1}); - auto mask = - at::logical_and(indices >= 0, indices < biggest_idx).nonzero().squeeze(); - auto filtered_indices = indices.index_select(0, mask); - auto filtered_weights = weights.index_select(0, mask); - auto num_invalid_indices = indices.size(0) - filtered_indices.size(0); - if (num_invalid_indices > 0) { - XLOG(INFO) << "[TBE_ID" << unique_id_ - << "] number of invalid indices: " << num_invalid_indices; - } - // 1. Transform local row indices to embedding table global row indices - at::Tensor table_indices = - (at::searchsorted(table_sizes_, filtered_indices, false, true) - 1) - .to(torch::kInt8); - auto tb_ac = table_indices.accessor(); - auto indices_ac = filtered_indices.accessor(); - auto tb_sizes_ac = table_sizes_.accessor(); - std::vector global_indices(tb_ac.size(0), 0); - std::vector shard_indices(tb_ac.size(0), 0); - - for (int i = 0; i < tb_ac.size(0); ++i) { - int tb_idx = tb_ac[i]; - global_indices[i] = - indices_ac[i] - tb_sizes_ac[tb_idx] + table_offsets_[tb_idx]; - // hash to shard - // if we do row range sharding, also shard here. - auto fqn = table_names_[tb_idx]; - auto hash_key = folly::to(fqn, global_indices[i]); - auto shard_id = - furcHash(hash_key.data(), hash_key.size(), res_store_shards_); - shard_indices[i] = shard_id; - } - auto global_indices_tensor = at::tensor(global_indices); - auto shard_indices_tensor = at::tensor(shard_indices); - auto total_rows = global_indices_tensor.size(0); - XLOG_EVERY_MS(INFO, 60000) - << "[TBE_ID" << unique_id_ << "] hash and gloablize rows " << total_rows - << " in: " << stop_watch.elapsed().count() << "ms"; - stop_watch.reset(); - - auto res_client = get_res_client(res_server_port_); - // 2. 
Split by shards - for (int i = 0; i < res_store_shards_; ++i) { - auto shrad_mask = shard_indices_tensor.eq(i).nonzero().squeeze(); - auto table_indices_masked = table_indices.index_select(0, shrad_mask); - auto rows_in_shard = table_indices_masked.numel(); - if (rows_in_shard == 0) { - continue; - } - auto global_indices_masked = - global_indices_tensor.index_select(0, shrad_mask); - auto weights_masked = filtered_weights.index_select(0, shrad_mask); - - if (weights_masked.size(0) != rows_in_shard || - global_indices_masked.numel() != rows_in_shard) { - XLOG(ERR) - << "[TBE_ID" << unique_id_ - << "] don't send the request for size mismatched tensors table: " - << rows_in_shard << " weights: " << weights_masked.size(0) - << " global_indices: " << global_indices_masked.numel(); - continue; - } - SetEmbeddingsRequest req; - req.shardId() = i; - req.fqns() = table_names_; - - req.tableIndices() = - torch::distributed::wireDumpTensor(table_indices_masked); - req.rowIndices() = - torch::distributed::wireDumpTensor(global_indices_masked); - req.weights() = torch::distributed::wireDumpTensor(weights_masked); - co_await res_client->co_setEmbeddings(req); - } - co_return; -} - -void EmbeddingKVDB::copy_and_enqueue_stream_tensors( - const at::Tensor& indices, - const at::Tensor& weights, - const at::Tensor& count) { - auto rec = torch::autograd::profiler::record_function_enter_new( - "## EmbeddingKVDB::copy_and_enqueue_stream_tensors ##"); - auto stream_item = - tensor_copy(indices, weights, count, kv_db::RocksdbWriteMode::STREAM); - weights_to_stream_queue_.enqueue(stream_item); - rec->record.end(); -} - -void EmbeddingKVDB::join_stream_tensor_copy_thread() { - auto rec = torch::autograd::profiler::record_function_enter_new( - "## EmbeddingKVDB::join_stream_tensor_copy_thread ##"); - if (stream_tensor_copy_thread_ != nullptr && - stream_tensor_copy_thread_->joinable()) { - stream_tensor_copy_thread_->join(); - } - rec->record.end(); -} - -void EmbeddingKVDB::join_weights_stream_thread() { - if (weights_stream_thread_ != nullptr && weights_stream_thread_->joinable()) { - stop_ = true; - weights_stream_thread_->join(); - } } -uint64_t EmbeddingKVDB::get_weights_to_stream_queue_size() { - return weights_to_stream_queue_.size(); -} -#endif - void EmbeddingKVDB::update_cache_and_storage( const at::Tensor& indices, const at::Tensor& weights, @@ -491,8 +297,14 @@ void EmbeddingKVDB::stream_cuda( check_tensor_type_consistency(indices, weights); // take reference to self to avoid lifetime issues. auto self = shared_from_this(); - std::function* functor = new std::function( - [=]() { self->stream(indices, weights, count, blocking_tensor_copy); }); + std::function* functor = new std::function([=]() { + self->raw_embedding_streamer_->stream( + indices, + weights, + count, + true, /*require_tensor_copy*/ + blocking_tensor_copy); + }); AT_CUDA_CHECK(cudaStreamAddCallback( at::cuda::getCurrentCUDAStream(), kv_db_utils::cuda_callback_func, @@ -508,8 +320,9 @@ void EmbeddingKVDB::stream_sync_cuda() { "## EmbeddingKVDB::stream_sync_cuda ##"); // take reference to self to avoid lifetime issues. 
auto self = shared_from_this(); - std::function* functor = new std::function( - [=]() { self->join_stream_tensor_copy_thread(); }); + std::function* functor = new std::function([=]() { + self->raw_embedding_streamer_->join_stream_tensor_copy_thread(); + }); AT_CUDA_CHECK(cudaStreamAddCallback( at::cuda::getCurrentCUDAStream(), kv_db_utils::cuda_callback_func, @@ -593,7 +406,6 @@ void EmbeddingKVDB::set( return; } CHECK_EQ(max_D_, weights.size(1)); - auto rec = torch::autograd::profiler::record_function_enter_new( "## EmbeddingKVDB::set_callback ##"); // defer the L2 cache/rocksdb update to the background thread as it could @@ -692,32 +504,6 @@ void EmbeddingKVDB::get( rec->record.end(); } -void EmbeddingKVDB::stream( - const at::Tensor& indices, - const at::Tensor& weights, - const at::Tensor& count, - bool blocking_tensor_copy) { - if (!enable_raw_embedding_streaming_) { - return; - } - auto rec = torch::autograd::profiler::record_function_enter_new( - "## EmbeddingKVDB::stream_callback ##"); - if (blocking_tensor_copy) { - copy_and_enqueue_stream_tensors(indices, weights, count); - return; - } - // Make sure the previous thread is done before starting a new one - join_stream_tensor_copy_thread(); - // Cuda dispatches the host callbacks all in the same CPU thread. But the - // callbacks don't need to be serialized. - // So, We need to spin up a new thread to unblock the CUDA stream, so the CUDA - // can continue executing other host callbacks, eg. get/evict. - stream_tensor_copy_thread_ = std::make_unique([=, this]() { - copy_and_enqueue_stream_tensors(indices, weights, count); - }); - rec->record.end(); -} - std::shared_ptr EmbeddingKVDB::get_cache( const at::Tensor& indices, const at::Tensor& count) { diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h index 656cb61067..40911249f6 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h @@ -42,6 +42,7 @@ #include #include "../dram_kv_embedding_cache/feature_evict.h" #include "fbgemm_gpu/split_embeddings_cache/cachelib_cache.h" +#include "fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h" #include "fbgemm_gpu/utils/dispatch_macros.h" namespace ssd { @@ -202,33 +203,6 @@ class EmbeddingKVDB : public std::enable_shared_from_this { const at::Tensor& count, int64_t sleep_ms = 0); - /// Stream out non-negative elements in and its paired embeddings - /// from for the first elements in the tensor. - /// It spins up a thread that will copy all 3 tensors to CPU and inject them - /// into the background queue which will be picked up by another set of thread - /// pools for streaming out to the thrift server (co-located on same host - /// now). - /// - /// This is used in cuda stream callback, which doesn't require to be - /// serialized with other callbacks, thus a separate thread is used to - /// maximize the overlapping with other callbacks. 
- /// - /// @param indices The 1D embedding index tensor, should skip on negative - /// value - /// @param weights The 2D tensor that each row(embeddings) is paired up with - /// relative element in - /// @param count A single element tensor that contains the number of indices - /// to be processed - /// @param blocking_tensor_copy whether to copy the tensors to be streamed in - /// a blocking manner - /// - /// @return None - void stream( - const at::Tensor& indices, - const at::Tensor& weights, - const at::Tensor& count, - bool blocking_tensor_copy = true); - /// storage tier counterpart of function get() virtual folly::SemiFuture> get_kv_db_async( const at::Tensor& indices, @@ -441,34 +415,6 @@ class EmbeddingKVDB : public std::enable_shared_from_this { return 0; } -#ifdef FBGEMM_FBCODE - folly::coro::Task tensor_stream( - const at::Tensor& indices, - const at::Tensor& weights); - /* - * Copy the indices, weights and count tensors and enqueue them for - * asynchronous stream. - */ - void copy_and_enqueue_stream_tensors( - const at::Tensor& indices, - const at::Tensor& weights, - const at::Tensor& count); - - /* - * Join the stream tensor copy thread, make sure the thread is properly - * finished before creating new. - */ - void join_stream_tensor_copy_thread(); - - /* - * FOR TESTING: Join the weight stream thread, make sure the thread is - * properly finished for destruction and testing. - */ - void join_weights_stream_thread(); - // FOR TESTING: get queue size. - uint64_t get_weights_to_stream_queue_size(); -#endif - private: /// Find non-negative embedding indices in and shard them into /// #cachelib_pools pieces to be lookedup in parallel @@ -599,17 +545,7 @@ class EmbeddingKVDB : public std::enable_shared_from_this { // -- commone path std::atomic total_cache_update_duration_{0}; - - // -- raw embedding streaming - bool enable_raw_embedding_streaming_; - int64_t res_store_shards_; - int64_t res_server_port_; - std::vector table_names_; - std::vector table_offsets_; - at::Tensor table_sizes_; - std::unique_ptr weights_stream_thread_; - folly::UMPSCQueue weights_to_stream_queue_; - std::unique_ptr stream_tensor_copy_thread_; + std::unique_ptr raw_embedding_streamer_; }; // class EmbeddingKVDB } // namespace kv_db diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/test/ssd_table_batched_embeddings_test.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/test/ssd_table_batched_embeddings_test.cpp index 0b2bed94bb..b0e5430a00 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/test/ssd_table_batched_embeddings_test.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/test/ssd_table_batched_embeddings_test.cpp @@ -10,12 +10,6 @@ #include #include #include "deeplearning/fbgemm/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h" -#ifdef FBGEMM_FBCODE -#include -#include "aiplatform/gmpp/experimental/training_ps/gen-cpp2/TrainingParameterServerService.h" -#include "servicerouter/client/cpp2/mocks/MockSRClientFactory.h" -#include "thrift/lib/cpp2/util/ScopedServerInterfaceThread.h" -#endif using namespace ::testing; constexpr int64_t EMBEDDING_DIMENSION = 8; @@ -83,21 +77,6 @@ class MockEmbeddingRocksDB : public ssd::EmbeddingRocksDB { (override)); }; -#ifdef FBGEMM_FBCODE -class MockTrainingParameterServerService - : public ::apache::thrift::ServiceHandler< - aiplatform::gmpp::experimental::training_ps:: - TrainingParameterServerService> { - public: - MOCK_METHOD( - folly::coro::Task>, - co_setEmbeddings, - (std::unique_ptr< - 
aiplatform::gmpp::experimental::training_ps::SetEmbeddingsRequest>)); -}; -#endif - std::unique_ptr getMockEmbeddingRocksDB( int num_shards, const std::string& dir, @@ -187,125 +166,3 @@ TEST(SSDTableBatchedEmbeddingsTest, TestToggleCompactionFailOnThronw) { { mock_embedding_rocks->toggle_compaction(true); }, "Failed to toggle compaction to 1 with exception std::runtime_error: some error message"); } - -#ifdef FBGEMM_FBCODE -TEST(KvDbTableBatchedEmbeddingsTest, TestTensorStream) { - int num_shards = 8; - std::vector table_names = {"tb1", "tb2", "tb3"}; - std::vector table_offsets = {0, 100, 300}; - std::vector table_sizes = {0, 50, 200, 300}; - auto mock_embedding_rocks = getMockEmbeddingRocksDB( - num_shards, - "tensor_stream", - true, - table_names, - table_offsets, - table_sizes); - // Mock TrainingParameterServerService - auto mock_service = std::make_shared(); - auto mock_server = - std::make_shared( - mock_service, - "::1", - 0, - facebook::services::TLSConfig::applyDefaultsToThriftServer); - auto& mock_client_factory = - facebook::servicerouter::getMockSRClientFactory(false /* strict */); - mock_client_factory.registerMockService( - "realtime.delta.publish.esr", mock_server); - - auto invalid_ind = at::tensor( - {300, 301, 999}, at::TensorOptions().device(at::kCPU).dtype(at::kLong)); - auto weights = at::randn( - {invalid_ind.size(0), EMBEDDING_DIMENSION}, - at::TensorOptions().device(at::kCPU).dtype(c10::kFloat)); - EXPECT_CALL(*mock_service, co_setEmbeddings(_)).Times(0); - folly::coro::blockingWait( - mock_embedding_rocks->tensor_stream(invalid_ind, weights)); - - auto ind = at::tensor( - {10, 2, 1, 150, 170, 230, 280}, - at::TensorOptions().device(at::kCPU).dtype(at::kLong)); - weights = at::randn( - {ind.size(0), 8}, - at::TensorOptions().device(at::kCPU).dtype(c10::kFloat)); - EXPECT_CALL(*mock_service, co_setEmbeddings(_)) - .Times(3) // 3 shards with consistent hashing - .WillRepeatedly(folly::coro::gmock_helpers::CoInvoke( - [](std::unique_ptr< - aiplatform::gmpp::experimental::training_ps::SetEmbeddingsRequest> - request) - -> folly::coro::Task< - std::unique_ptr> { - co_return std::make_unique< - aiplatform::gmpp::experimental::training_ps:: - SetEmbeddingsResponse>(); - })); - folly::coro::blockingWait(mock_embedding_rocks->tensor_stream(ind, weights)); -} -#endif - -#ifdef FBGEMM_FBCODE -TEST(KvDbTableBatchedEmbeddingsTest, TestStream) { - int num_shards = 8; - std::vector table_names = {"tb1", "tb2", "tb3"}; - std::vector table_offsets = {0, 100, 300}; - std::vector table_sizes = {0, 50, 200, 300}; - auto mock_embedding_rocks = getMockEmbeddingRocksDB( - num_shards, "test_stream", true, table_names, table_offsets, table_sizes); - auto ind = at::tensor( - {10, 2, 1, 150, 170, 230, 280}, - at::TensorOptions().device(at::kCPU).dtype(at::kLong)); - auto weights = at::randn( - {ind.size(0), 8}, - at::TensorOptions().device(at::kCPU).dtype(c10::kFloat)); - auto count = at::tensor( - {ind.size(0)}, at::TensorOptions().device(at::kCPU).dtype(at::kLong)); - // stop the dequeue thread to get accurate queue size - mock_embedding_rocks->join_weights_stream_thread(); - - // blocking - mock_embedding_rocks->stream(ind, weights, count, true); - EXPECT_EQ(mock_embedding_rocks->get_weights_to_stream_queue_size(), 1); - // non-blocking - mock_embedding_rocks->stream(ind, weights, count, false); - EXPECT_EQ(mock_embedding_rocks->get_weights_to_stream_queue_size(), 1); - mock_embedding_rocks->join_stream_tensor_copy_thread(); - 
EXPECT_EQ(mock_embedding_rocks->get_weights_to_stream_queue_size(), 2);
-  mock_embedding_rocks.reset();
-
-  // E2E
-  auto default_response =
-      [](std::unique_ptr<
-          aiplatform::gmpp::experimental::training_ps::SetEmbeddingsRequest>
-             request)
-      -> folly::coro::Task<std::unique_ptr<
-          aiplatform::gmpp::experimental::training_ps::SetEmbeddingsResponse>> {
-    co_return std::make_unique<
-        aiplatform::gmpp::experimental::training_ps::SetEmbeddingsResponse>();
-  };
-  // Mock TrainingParameterServerService
-  auto mock_service = std::make_shared<MockTrainingParameterServerService>();
-  auto mock_server =
-      std::make_shared<apache::thrift::ScopedServerInterfaceThread>(
-          mock_service,
-          "::1",
-          0,
-          facebook::services::TLSConfig::applyDefaultsToThriftServer);
-  auto& mock_client_factory =
-      facebook::servicerouter::getMockSRClientFactory(false /* strict */);
-  mock_client_factory.registerMockService(
-      "realtime.delta.publish.esr", mock_server);
-  EXPECT_CALL(*mock_service, co_setEmbeddings(_))
-      .Times(3) // 3 shards with consistent hashing
-      .WillRepeatedly(folly::coro::gmock_helpers::CoInvoke(default_response));
-
-  mock_embedding_rocks = getMockEmbeddingRocksDB(
-      num_shards, "test_stream", true, table_names, table_offsets, table_sizes);
-  mock_embedding_rocks->stream(ind, weights, count, true);
-  // make sure dequeue finished.
-  std::this_thread::sleep_for(std::chrono::seconds(1));
-  mock_embedding_rocks->join_weights_stream_thread();
-}
-#endif

From 4b8377997bdee8300fd7460d6b58e63ef6e61773 Mon Sep 17 00:00:00 2001
From: Zheng Qi
Date: Wed, 20 Aug 2025 21:17:19 -0700
Subject: [PATCH 2/2] Add tracking and streaming logic to
 SplitTableBatchedEmbeddingBagsCodegen (#4741)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/4741

X-link: https://github.com/facebookresearch/FBGEMM/pull/1762

It follows logic similar to SSD TBE https://fburl.com/code/fxdcxma3
It tries to
1. store the updated ids/count
2. on the next iteration, stream out the updated embeddings and ids before
   the embedding cache is populated again.

The prefetch pipeline logic is also the same as in SSD TBE.

Differential Revision: D78438757
---
 ...t_table_batched_embeddings_ops_training.py | 104 ++++++++++++++
 fbgemm_gpu/test/tbe/training/forward_test.py  | 111 ++++++++++++++++++
 2 files changed, 215 insertions(+)

diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index 05e9cdb8c8..068fed8dea 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -22,6 +22,7 @@
 
 import torch  # usort:skip
 from torch import nn, Tensor  # usort:skip
+from torch.autograd.profiler import record_function  # usort:skip
 
 # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
@@ -626,6 +627,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     lxu_cache_locations_list: List[Tensor]
     lxu_cache_locations_empty: Tensor
     timesteps_prefetched: List[int]
+    prefetched_info: List[Tuple[Tensor, Tensor]]
     record_cache_metrics: RecordCacheMetrics
     # pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
uvm_cache_stats: torch.Tensor @@ -690,6 +692,8 @@ def __init__( # noqa C901 embedding_table_index_type: torch.dtype = torch.int64, embedding_table_offset_type: torch.dtype = torch.int64, embedding_shard_info: Optional[List[Tuple[int, int, int, int]]] = None, + enable_raw_embedding_streaming: bool = False, + res_params: Optional[RESParams] = None, ) -> None: super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__() self.uuid = str(uuid.uuid4()) @@ -700,6 +704,7 @@ def __init__( # noqa C901 ) self.logging_table_name: str = self.get_table_name_for_logging(table_names) + self.enable_raw_embedding_streaming: bool = enable_raw_embedding_streaming self.pooling_mode = pooling_mode self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE @@ -1460,6 +1465,30 @@ def __init__( # noqa C901 ) self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type + self.prefetched_info: List[Tuple[Tensor, Tensor]] = torch.jit.annotate( + List[Tuple[Tensor, Tensor]], [] + ) + if self.enable_raw_embedding_streaming: + self.res_params: RESParams = res_params or RESParams() + self.res_params.table_sizes = [0] + list(accumulate(rows)) + res_port_from_env = os.getenv("LOCAL_RES_PORT") + self.res_params.res_server_port = ( + int(res_port_from_env) if res_port_from_env else 0 + ) + # pyre-fixme[4]: Attribute must be annotated. + self._raw_embedding_streamer = torch.classes.fbgemm.RawEmbeddingStreamer( + self.uuid, + self.enable_raw_embedding_streaming, + self.res_params.res_store_shards, + self.res_params.res_server_port, + self.res_params.table_names, + self.res_params.table_offsets, + self.res_params.table_sizes, + ) + logging.info( + f"{self.uuid} raw embedding streaming enabled with {self.res_params=}" + ) + @torch.jit.ignore def log(self, msg: str) -> None: """ @@ -2521,7 +2550,13 @@ def _prefetch( self.local_uvm_cache_stats.zero_() self._report_io_size_count("prefetch_input", indices) + # streaming before updating the cache + self.raw_embedding_stream() + final_lxu_cache_locations = torch.empty_like(indices, dtype=torch.int32) + linear_cache_indices_merged = torch.zeros( + 0, dtype=indices.dtype, device=indices.device + ) for ( partial_indices, partial_lxu_cache_locations, @@ -2537,6 +2572,9 @@ def _prefetch( vbe_metadata.max_B if vbe_metadata is not None else -1, base_offset, ) + linear_cache_indices_merged = torch.cat( + [linear_cache_indices_merged, linear_cache_indices] + ) if ( self.record_cache_metrics.record_cache_miss_counter @@ -2617,6 +2655,23 @@ def _prefetch( if self.should_log(): self.print_uvm_cache_stats(use_local_cache=False) + if self.enable_raw_embedding_streaming: + with record_function( + "## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid) + ): + ( + linear_unique_indices, + linear_unique_indices_length, + _, + ) = torch.ops.fbgemm.get_unique_indices( + linear_cache_indices_merged, + self.total_cache_hash_size, + compute_count=False, + ) + self.prefetched_info.append( + (linear_unique_indices, linear_unique_indices_length) + ) + def should_log(self) -> bool: """Determines if we should log for this step, using exponentially decreasing frequency. 
@@ -3829,6 +3884,55 @@ def _debug_print_input_stats_factory_null( return _debug_print_input_stats_factory_impl return _debug_print_input_stats_factory_null + @torch.jit.ignore + def raw_embedding_stream(self) -> None: + if not self.enable_raw_embedding_streaming: + return None + # when pipelining is enabled + # prefetch in iter i happens before the backward sparse in iter i - 1 + # so embeddings for iter i - 1's changed ids are not updated. + # so we can only fetch the indices from the iter i - 2 + # when pipelining is disabled + # prefetch in iter i happens before forward iter i + # so we can get the iter i - 1's changed ids safely. + target_prev_iter = 1 + if self.prefetch_pipeline: + target_prev_iter = 2 + if not len(self.prefetched_info) > (target_prev_iter - 1): + return None + with record_function( + "## uvm_lookup_prefetched_rows {} {} ##".format(self.timestep, self.uuid) + ): + (updated_indices, updated_count) = self.prefetched_info.pop(0) + updated_locations = torch.ops.fbgemm.lxu_cache_lookup( + updated_indices, + self.lxu_cache_state, + self.total_cache_hash_size, + gather_cache_stats=False, # not collecting cache stats + num_uniq_cache_indices=updated_count, + ) + updated_weights = torch.empty( + [updated_indices.size()[0], self.max_D_cache], + # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]` + dtype=self.lxu_cache_weights.dtype, + # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `device`, expected `Union[None, int, str, device]` but got `Union[Module, device, Tensor]` + device=self.lxu_cache_weights.device, + ) + torch.ops.fbgemm.masked_index_select( + updated_weights, + updated_locations, + self.lxu_cache_weights, + updated_count, + ) + # stream weights + self._raw_embedding_streamer.stream( + updated_indices.to(device=torch.device("cpu")), + updated_weights.to(device=torch.device("cpu")), + updated_count.to(device=torch.device("cpu")), + False, # require_tensor_copy + False, # blocking_tensor_copy + ) + class DenseTableBatchedEmbeddingBagsCodegen(nn.Module): """ diff --git a/fbgemm_gpu/test/tbe/training/forward_test.py b/fbgemm_gpu/test/tbe/training/forward_test.py index 20bd998b6d..2109ad30c1 100644 --- a/fbgemm_gpu/test/tbe/training/forward_test.py +++ b/fbgemm_gpu/test/tbe/training/forward_test.py @@ -12,6 +12,7 @@ import math import random import unittest +from unittest.mock import MagicMock, patch import hypothesis.strategies as st import numpy as np @@ -24,6 +25,7 @@ ) from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( ComputeDevice, + RESParams, SplitTableBatchedEmbeddingBagsCodegen, ) from fbgemm_gpu.tbe.utils import ( @@ -129,6 +131,8 @@ def execute_forward_( # noqa C901 use_cpu: bool, output_dtype: SparseType, use_experimental_tbe: bool, + enable_raw_embedding_streaming: bool = False, + prefetch_pipeline: bool = False, ) -> None: # NOTE: cache is not applicable to CPU version. 
assume(not use_cpu or not use_cache) @@ -158,6 +162,10 @@ def execute_forward_( # noqa C901 and pooling_mode != PoolingMode.NONE ) ) + # NOTE: Raw embedding streaming requires UVM cache + assume(not enable_raw_embedding_streaming or use_cache) + # NOTE: Raw embedding streaming not supported on CPU + assume(not enable_raw_embedding_streaming or not use_cpu) emb_op = SplitTableBatchedEmbeddingBagsCodegen if pooling_mode == PoolingMode.SUM: @@ -285,6 +293,16 @@ def execute_forward_( # noqa C901 else: f = torch.cat(fs, dim=0).view(-1, D) + # Create RES parameters if raw embedding streaming is enabled + res_params = None + if enable_raw_embedding_streaming: + res_params = RESParams( + res_store_shards=1, + table_names=[f"table_{i}" for i in range(T)], + table_offsets=[sum(Es[:i]) for i in range(T + 1)], + table_sizes=Es, + ) + # Create a TBE op cc = emb_op( embedding_specs=[ @@ -305,6 +323,9 @@ def execute_forward_( # noqa C901 pooling_mode=pooling_mode, output_dtype=output_dtype, use_experimental_tbe=use_experimental_tbe, + prefetch_pipeline=prefetch_pipeline, + enable_raw_embedding_streaming=enable_raw_embedding_streaming, + res_params=res_params, ) # Test torch JIT script compatibility if not use_cpu: @@ -1158,6 +1179,96 @@ def test_forward_fused_pooled_emb_quant( cat_deq_lowp_pooled_output, cat_dq_fp32_pooled_output ) + def _check_raw_embedding_stream_call_counts( + self, + mock_raw_embedding_stream: unittest.mock.Mock, + num_iterations: int, + prefetch_pipeline: bool, + L: int, + ) -> None: + # For TBE (not SSD), raw_embedding_stream is called once per prefetch + # when there's data to stream + expected_calls = num_iterations if L > 0 else 0 + if prefetch_pipeline: + # With prefetch pipeline, there might be fewer calls initially + expected_calls = max(0, expected_calls - 1) + + self.assertGreaterEqual(mock_raw_embedding_stream.call_count, 0) + # Allow some flexibility in call count due to caching behavior + self.assertLessEqual(mock_raw_embedding_stream.call_count, expected_calls + 2) + + @unittest.skipIf(*gpu_unavailable) + @given( + T=st.integers(min_value=1, max_value=5), + D=st.integers(min_value=2, max_value=64), + B=st.integers(min_value=1, max_value=32), + log_E=st.integers(min_value=3, max_value=4), + L=st.integers(min_value=1, max_value=10), + weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]), + cache_algorithm=st.sampled_from(CacheAlgorithm), + pooling_mode=st.sampled_from([PoolingMode.SUM, PoolingMode.MEAN]), + weighted=st.booleans(), + mixed=st.booleans(), + prefetch_pipeline=st.booleans(), + ) + @settings( + verbosity=VERBOSITY, + max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], + ) + def test_forward_raw_embedding_streaming( + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weights_precision: SparseType, + cache_algorithm: CacheAlgorithm, + pooling_mode: PoolingMode, + weighted: bool, + mixed: bool, + prefetch_pipeline: bool, + ) -> None: + """Test raw embedding streaming functionality integrated with forward pass.""" + num_iterations = 5 + # only LRU supports prefetch_pipeline + assume(not prefetch_pipeline or cache_algorithm == CacheAlgorithm.LRU) + + with patch( + "fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer" + ) as mock_streamer_class: + # Mock the RawEmbeddingStreamer class + mock_streamer_instance = MagicMock() + mock_streamer_class.return_value = mock_streamer_instance + + # Run multiple 
iterations to test streaming behavior + for _ in range(num_iterations): + self.execute_forward_( + T=T, + D=D, + B=B, + log_E=log_E, + L=L, + weights_precision=weights_precision, + weighted=weighted, + mixed=mixed, + mixed_B=False, # Keep simple for streaming tests + use_cache=True, # Required for streaming + cache_algorithm=cache_algorithm, + pooling_mode=pooling_mode, + use_cpu=False, # Streaming not supported on CPU + output_dtype=SparseType.FP32, + use_experimental_tbe=False, + enable_raw_embedding_streaming=True, + prefetch_pipeline=prefetch_pipeline, + ) + + self._check_raw_embedding_stream_call_counts( + mock_streamer_instance, num_iterations, prefetch_pipeline, L + ) + if __name__ == "__main__": unittest.main()
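
Reviewer note, not part of the patch: below is a minimal sketch of driving the
newly exported TorchScript class from Python, assuming an fbgemm_gpu build that
includes this patch. The argument order mirrors the registration in
split_embeddings_cache_ops.cpp, the table layout comes from
raw_embedding_streamer_test.cpp, and names such as "tbe_0" are illustrative
only.

    import torch
    import fbgemm_gpu  # noqa: F401  # loads the fbgemm operators and classes

    streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
        "tbe_0",  # unique_id
        True,  # enable_raw_embedding_streaming
        3,  # res_store_shards
        0,  # res_server_port
        ["tb1", "tb2", "tb3"],  # table_names
        [0, 100, 300],  # table_offsets
        [0, 50, 200, 300],  # table_sizes (cumulative row counts)
    )

    indices = torch.tensor([10, 2, 1, 150], dtype=torch.long)
    weights = torch.randn(indices.size(0), 8)
    count = torch.tensor([indices.size(0)], dtype=torch.long)

    # require_tensor_copy=True, blocking_tensor_copy=True takes the same
    # blocking copy-and-enqueue path exercised by TestStreamWithCopy.
    streamer.stream(indices, weights, count, True, True)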